convert-llama-hf-to-gguf.py

#!/usr/bin/env python3
# HF llama --> gguf conversion

import gguf
import os
import sys
import struct
import json
import numpy as np
import torch
import argparse

from typing import TYPE_CHECKING, Any, List, Optional
from pathlib import Path
from sentencepiece import SentencePieceProcessor

if TYPE_CHECKING:
    from typing import TypeAlias

# NDArray = np.ndarray[Any, Any]
# the string annotation keeps this usable on Python versions where
# typing.TypeAlias is not available at runtime (it was added in 3.10)
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'

# reverse HF permute back to original pth layout
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
    if n_kv_head is not None and n_head != n_kv_head:
        n_head //= n_kv_head

    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))

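# Illustrative sanity check (example only, not part of the conversion flow):
# the permutation only reorders rows within each attention head, so the output
# shape always equals the input shape.
#
#   w = np.arange(64, dtype=np.float32).reshape(8, 8)
#   assert reverse_hf_permute(w, n_head=2).shape == (8, 8)
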
def count_model_parts(dir_model: Path) -> int:
    num_parts = 0
    for filename in os.listdir(dir_model):
        if filename.startswith("pytorch_model-"):
            num_parts += 1

    if num_parts > 0:
        print("gguf: found " + str(num_parts) + " model parts")

    return num_parts

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Convert a HuggingFace LLaMA model to a GGML compatible file")
    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
    # nargs="?" makes the positional optional so the declared default of 1 (float16) actually applies
    parser.add_argument("ftype", type=int, choices=[0, 1], nargs="?", default=1, help="output format - use 0 for float32, 1 for float16")
    return parser.parse_args()

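# Example invocation (the model path is a placeholder):
#
#   python convert-llama-hf-to-gguf.py /path/to/hf-llama-model 1
#
# reads the HF checkpoint directory and writes ggml-model-f16.gguf next to it;
# pass 0 instead of 1 for float32 output, or --outfile to choose the destination.
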
args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
    print(f'Error: {args.model} is not a directory', file=sys.stderr)
    sys.exit(1)

# possible tensor data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
    fname_out = args.outfile
else:
    # output in the same directory as the model by default
    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

  56. print("gguf: loading model "+dir_model.name)
  57. with open(dir_model / "config.json", "r", encoding="utf-8") as f:
  58. hparams = json.load(f)
  59. if hparams["architectures"][0] != "LlamaForCausalLM":
  60. print("Model architecture not supported: " + hparams["architectures"][0])
  61. sys.exit()
  62. # get number of model parts
  63. num_parts = count_model_parts(dir_model)
ARCH = gguf.MODEL_ARCH.LLAMA
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

block_count = hparams["num_hidden_layers"]
head_count = hparams["num_attention_heads"]

if "num_key_value_heads" in hparams:
    head_count_kv = hparams["num_key_value_heads"]
else:
    head_count_kv = head_count

if "_name_or_path" in hparams:
    hf_repo = hparams["_name_or_path"]
else:
    hf_repo = ""

if "max_sequence_length" in hparams:
    ctx_length = hparams["max_sequence_length"]
elif "max_position_embeddings" in hparams:
    ctx_length = hparams["max_position_embeddings"]
else:
    print("gguf: cannot find ctx length parameter.")
    sys.exit(1)

gguf_writer.add_name(dir_model.name)
gguf_writer.add_source_hf_repo(hf_repo)
gguf_writer.add_tensor_data_layout("Meta AI original pth")
gguf_writer.add_context_length(ctx_length)
gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
gguf_writer.add_head_count(head_count)
gguf_writer.add_head_count_kv(head_count_kv)
gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])

if "rope_scaling" in hparams and hparams["rope_scaling"] is not None and "factor" in hparams["rope_scaling"]:
    if "type" in hparams["rope_scaling"]:
        if hparams["rope_scaling"]["type"] == "linear":
            gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"])

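# For reference, the rope_scaling check above matches a config.json entry of this
# shape (the factor value shown is only an example):
#
#   "rope_scaling": { "type": "linear", "factor": 2.0 }
#
# Only linear scaling is forwarded to the writer; other scaling types are ignored.
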
# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: List[bytes] = []
scores: List[float] = []
toktypes: List[int] = []

tokenizer_model_file = dir_model / 'tokenizer.model'
if not tokenizer_model_file.is_file():
    print(f'Error: Missing {tokenizer_model_file}', file=sys.stderr)
    sys.exit(1)

# vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab, scores and token types")

tokenizer = SentencePieceProcessor(str(tokenizer_model_file))

for i in range(tokenizer.vocab_size()):
    text: bytes
    score: float

    piece = tokenizer.id_to_piece(i)
    text = piece.encode("utf-8")
    score = tokenizer.get_score(i)

    toktype = 1  # default to normal token type
    if tokenizer.is_unknown(i):
        toktype = 2
    if tokenizer.is_control(i):
        toktype = 3

    # toktype = 4 is user-defined = tokens from added_tokens.json

    if tokenizer.is_unused(i):
        toktype = 5
    if tokenizer.is_byte(i):
        toktype = 6

    tokens.append(text)
    scores.append(score)
    toktypes.append(toktype)

added_tokens_file = dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
    with open(added_tokens_file, "r", encoding="utf-8") as f:
        addtokens_json = json.load(f)

    print("gguf: get added tokens")

    for key in addtokens_json:
        tokens.append(key.encode("utf-8"))
        scores.append(-1000.0)
        toktypes.append(4)  # user-defined token type

gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH, block_count)

# tensor info
print("gguf: get tensor metadata")

if num_parts == 0:
    part_names = iter(("pytorch_model.bin",))
else:
    part_names = (
        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
    )

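# For sharded checkpoints the generator above yields names such as
# "pytorch_model-00001-of-00002.bin" (example with num_parts == 2), matching the
# zero-padded shard naming used by HF multi-part checkpoints.
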
for part_name in part_names:
    if args.vocab_only:
        break
    print("gguf: loading model part '" + part_name + "'")
    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

    for name in model_part.keys():
        data = model_part[name]

        # we don't need these
        if name.endswith(".rotary_emb.inv_freq"):
            continue

        old_dtype = data.dtype

        # convert any unsupported data types to float32
        if data.dtype != torch.float16 and data.dtype != torch.float32:
            data = data.to(torch.float32)

        data = data.squeeze().numpy()

        # reverse permute these
        if name.endswith(".q_proj.weight"):
            data = reverse_hf_permute(data, head_count)
        if name.endswith(".k_proj.weight"):
            data = reverse_hf_permute(data, head_count, head_count_kv)

        # map tensor names
        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
        if new_name is None:
            print("Cannot map tensor '" + name + "'")
            sys.exit(1)

        n_dims = len(data.shape)
        data_dtype = data.dtype

        # if f32 desired, convert any float16 to float32
        if ftype == 0 and data_dtype == np.float16:
            data = data.astype(np.float32)

        # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data = data.astype(np.float32)

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            data = data.astype(np.float16)

        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

        gguf_writer.add_tensor(new_name, data)

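# Summary of the dtype policy applied in the loop above:
#   ftype == 0 (f32): float16 tensors are up-converted to float32.
#   ftype == 1 (f16): 2-D ".weight" float32 tensors are stored as float16;
#                     1-D tensors (norms, biases) are kept in float32.
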
  193. print("gguf: write header")
  194. gguf_writer.write_header_to_file()
  195. print("gguf: write metadata")
  196. gguf_writer.write_kv_data_to_file()
  197. if not args.vocab_only:
  198. print("gguf: write tensors")
  199. gguf_writer.write_tensors_to_file()
  200. gguf_writer.close()
  201. print(f"gguf: model successfully exported to '{fname_out}'")
  202. print("")