convert-ggml-to-pth.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # Author: github.com/ductai199x
  2. import argparse
  3. import os
  4. import struct
  5. import numpy as np
  6. import torch
  7. from numba import njit
  8. from tqdm.auto import tqdm
  9. def read_header(fin):
  10. values = struct.unpack("i" * 9, fin.read(4 * 9))
  11. _, _, vocab_size, dim, multiple_of, n_heads, n_layers, rot, ftype = values
  12. return {
  13. "vocab_size": vocab_size,
  14. "dim": dim,
  15. "multiple_of": multiple_of,
  16. "n_heads": n_heads,
  17. "n_layers": n_layers,
  18. }, ftype
  19. def read_tokens(fin, vocab_size):
  20. tokens = []
  21. for _ in range(vocab_size):
  22. text_len = struct.unpack("i", fin.read(4))[0]
  23. text_bytes = fin.read(text_len)
  24. try:
  25. text = text_bytes.decode()
  26. except UnicodeDecodeError:
  27. text = text_bytes.decode(errors="replace")
  28. score = struct.unpack("f", fin.read(4))[0]
  29. tokens.append((text, score))
  30. return tokens
@njit
def dequantize_weights_numba(fin_data, n_rows, n_cols):
    """Dequantize a raw Q4_0 tensor buffer into a float32 ``(n_rows, n_cols)`` array.

    The buffer is consumed sequentially, row by row, as blocks of ``qk = 32``
    columns. Each block is a 4-byte float32 scale ``d`` followed by 16 bytes,
    and each byte packs two 4-bit values. A nibble ``q`` decodes to
    ``(q - 8) * d``. The low nibble of byte ``i`` fills block column ``2*i``;
    the high nibble fills column ``2*i + 1``.

    Args:
        fin_data: raw tensor bytes. ``n_rows * (n_cols // 32) * 20`` bytes are
            consumed.
        n_rows: number of output rows.
        n_cols: number of output columns. Only the first
            ``(n_cols // 32) * 32`` columns are written; any remainder stays 0.

    Returns:
        ``np.float32`` array of shape ``(n_rows, n_cols)``.

    NOTE(review): this is compiled with numba ``@njit``. Numba does not
    bounds-check indexing by default, so a buffer shorter than expected is
    likely undefined behaviour rather than an IndexError. Callers should
    validate the length first.
    """
    qk = 32  # values per quantization block
    nb = n_cols // qk  # blocks per row
    bs = 4 + (qk // 2)  # bytes per block (scale + packed nibbles); not used below
    weights = np.zeros((n_rows, n_cols), dtype=np.float32)
    data_pos = 0  # read cursor into fin_data; never reset, blocks are contiguous
    for row in range(n_rows):
        for block in range(nb):
            # Per-block scale factor.
            d = np.frombuffer(fin_data[data_pos : data_pos + 4], dtype=np.float32)[0]
            data_pos += 4
            # 16 bytes holding 32 unsigned 4-bit values.
            packed_values = fin_data[data_pos : data_pos + (qk // 2)]
            data_pos += qk // 2
            for i in range(qk // 2):
                packed_value = packed_values[i]
                # Subtract 8 to map the unsigned nibble range 0..15 to -8..7.
                v0 = np.float32((packed_value & 0b00001111) - 8) * d
                v1 = np.float32((packed_value >> 4) - 8) * d
                # Low nibble -> even column, high nibble -> odd column.
                weights[row, block * qk + 2 * i] = v0
                weights[row, block * qk + 2 * i + 1] = v1
    return weights
  51. def dequantize_weights(fin, n_rows, n_cols):
  52. qk = 32
  53. nb = n_cols // qk
  54. data_size = n_rows * n_cols // 2 + n_rows * nb * 4
  55. fin_data = fin.read(data_size)
  56. return dequantize_weights_numba(fin_data, n_rows, n_cols)
  57. def read_variables(fin):
  58. model = {}
  59. pbar = tqdm(total=os.path.getsize(fin.name), unit="B", unit_scale=True, desc="Reading variables")
  60. while True:
  61. start_pos = fin.tell()
  62. try:
  63. n_dims, name_length, ftype_cur = struct.unpack("iii", fin.read(4 * 3))
  64. except struct.error:
  65. break
  66. shape = tuple(struct.unpack("i" * n_dims, fin.read(4 * n_dims)))
  67. shape = shape[::-1]
  68. name = fin.read(name_length).decode()
  69. # ensure tensor data is aligned
  70. tensor_data_offset = fin.tell()
  71. tensor_data_offset = (tensor_data_offset + 31) & -32
  72. fin.seek(tensor_data_offset)
  73. if ftype_cur == 2:
  74. # 4-bit quantized weights
  75. dtype = np.uint8
  76. data = dequantize_weights(fin, shape[0], shape[1])
  77. data = data.reshape(shape)
  78. elif ftype_cur == 0:
  79. dtype = np.float32
  80. data_size = np.prod(shape)
  81. data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
  82. elif ftype_cur == 1:
  83. dtype = np.float16
  84. data_size = np.prod(shape)
  85. data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
  86. model[name] = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16)
  87. pbar.update(fin.tell() - start_pos)
  88. return model
  89. def convert_to_hf_format(model, hparams):
  90. # This works for llama 7B, need to test with other models
  91. n_layers = hparams["n_layers"]
  92. n_heads = hparams["n_heads"]
  93. dim = hparams["dim"]
  94. dims_per_head = dim // n_heads
  95. base = 10000.0
  96. inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
  97. # permute for sliced rotary
  98. def permute(w):
  99. return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
  100. state_dict = {}
  101. for layer_i in range(n_layers):
  102. state_dict.update(
  103. {
  104. f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
  105. model[f"layers.{layer_i}.attention.wq.weight"]
  106. ),
  107. f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
  108. model[f"layers.{layer_i}.attention.wk.weight"]
  109. ),
  110. f"model.layers.{layer_i}.self_attn.v_proj.weight": model[
  111. f"layers.{layer_i}.attention.wv.weight"
  112. ],
  113. f"model.layers.{layer_i}.self_attn.o_proj.weight": model[
  114. f"layers.{layer_i}.attention.wo.weight"
  115. ],
  116. f"model.layers.{layer_i}.mlp.gate_proj.weight": model[
  117. f"layers.{layer_i}.feed_forward.w1.weight"
  118. ],
  119. f"model.layers.{layer_i}.mlp.down_proj.weight": model[
  120. f"layers.{layer_i}.feed_forward.w2.weight"
  121. ],
  122. f"model.layers.{layer_i}.mlp.up_proj.weight": model[
  123. f"layers.{layer_i}.feed_forward.w3.weight"
  124. ],
  125. f"model.layers.{layer_i}.input_layernorm.weight": model[
  126. f"layers.{layer_i}.attention_norm.weight"
  127. ],
  128. f"model.layers.{layer_i}.post_attention_layernorm.weight": model[
  129. f"layers.{layer_i}.ffn_norm.weight"
  130. ],
  131. }
  132. )
  133. state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
  134. state_dict.update(
  135. {
  136. "model.embed_tokens.weight": model["tok_embeddings.weight"],
  137. "model.norm.weight": model["norm.weight"],
  138. "lm_head.weight": model["output.weight"],
  139. }
  140. )
  141. return state_dict
  142. def chat(model, hparams, llama_dir):
  143. from transformers import (GenerationConfig, LlamaForCausalLM,
  144. LlamaTokenizer, StoppingCriteria,
  145. StoppingCriteriaList)
  146. from transformers.models.llama.configuration_llama import LlamaConfig
  147. class StoppingCriteriaSub(StoppingCriteria):
  148. def __init__(self):
  149. super().__init__()
  150. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]):
  151. print(tokenizer.decode(input_ids[0]), end="", flush=True)
  152. if input_ids[0][-1] == 13:
  153. return True
  154. return False
  155. config = LlamaConfig(
  156. vocab_size=hparams["vocab_size"],
  157. dim=hparams["dim"],
  158. num_hidden_layers=hparams["n_layers"],
  159. num_attention_heads=hparams["n_heads"],
  160. )
  161. llama = LlamaForCausalLM(config=config)
  162. llama.load_state_dict(state_dict=model, strict=True)
  163. tokenizer = LlamaTokenizer.from_pretrained(llama_dir)
  164. device = torch.device("cpu")
  165. llama = llama.to(device)
  166. ctx = """You are AI.
  167. This is a dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, respectful, direct, concise, should try to protect User's privacy, and knows its own limits. Also, AI must answer User and AI cannot stop the conversation by itself.
  168. User: Hello, AI.
  169. AI: Hello! How can I assist you today?
  170. """
  171. print(ctx.rstrip("\n"))
  172. while True:
  173. print("-" * 60)
  174. prompt = input("User: ")
  175. if ctx != "":
  176. ctx = f"{ctx}User: {prompt}\n"
  177. else:
  178. ctx = f"{prompt}\nAI:"
  179. ctx = (ctx[-1920:]) if len(ctx) >= 2048 else ctx
  180. print("-" * 60)
  181. if len(ctx.strip()) > 0:
  182. input_ids = tokenizer(ctx, return_tensors="pt")["input_ids"].to(device)
  183. generation_config = GenerationConfig(
  184. temperature=0.8,
  185. top_p=0.95,
  186. top_k=50,
  187. repetition_penalty=1.1764,
  188. )
  189. with torch.no_grad():
  190. generation_output = llama.generate(
  191. input_ids=input_ids,
  192. generation_config=generation_config,
  193. return_dict_in_generate=True,
  194. output_scores=True,
  195. max_length=2048,
  196. do_sample=True,
  197. stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub()]),
  198. )
  199. s = generation_output.sequences[0]
  200. decoded = tokenizer.decode(s)
  201. ctx = f"{decoded}\n"
  202. def main():
  203. parser = argparse.ArgumentParser()
  204. parser.add_argument(
  205. "--input_dir", "-i", type=str, required=True, help="The input directory containing the ggml files."
  206. )
  207. parser.add_argument(
  208. "--prefix",
  209. "-p",
  210. type=str,
  211. required=True,
  212. help="The prefix of the ggml files (ggml-model-f16 or ggml-model-q4_0).",
  213. )
  214. parser.add_argument(
  215. "--hf",
  216. action="store_true",
  217. help="Whether to save the model in the Hugging Face format. (default: False)",
  218. )
  219. parser.add_argument(
  220. "--chat", "-c", action="store_true", help="Whether to open a chat with the model. (default: False)"
  221. )
  222. args = parser.parse_args()
  223. llama_dir = os.path.abspath(f"{args.input_dir}/../")
  224. ggml_files = sorted(
  225. [f"{args.input_dir}/{f}" for f in os.listdir(args.input_dir) if f.startswith(args.prefix)]
  226. )
  227. fin = open(ggml_files[0], "rb")
  228. hparams, ftype = read_header(fin)
  229. tokens = read_tokens(fin, hparams["vocab_size"])
  230. model = read_variables(fin)
  231. for f in tqdm(ggml_files[1:]):
  232. fin = open(f, "rb")
  233. read_header(fin)
  234. read_tokens(fin, hparams["vocab_size"])
  235. model.update(read_variables(fin))
  236. if args.hf:
  237. model = convert_to_hf_format(model, hparams)
  238. pth_ckpt = {
  239. "state_dict": model,
  240. "hparams": hparams,
  241. "tokens": tokens,
  242. }
  243. torch.save(pth_ckpt, f"{args.input_dir}/{args.prefix}-to-torch.pth")
  244. if args.chat:
  245. if not args.hf:
  246. model = convert_to_hf_format(model, hparams)
  247. chat(model, hparams, llama_dir)
  248. if __name__ == "__main__":
  249. main()