convert-gptq-to-ggml.py

# Convert a GPTQ quantized LLaMA model to a ggml compatible file
# Based on: https://github.com/qwopqwop200/GPTQ-for-LLaMa
#
import os
import re
import sys
import json
import struct

import numpy as np
import torch
from sentencepiece import SentencePieceProcessor

if len(sys.argv) != 4:
    print("Usage: convert-gptq-to-ggml.py llamaXXb-4bit.pt tokenizer.model out.bin\n")
    sys.exit(1)
fname_model = sys.argv[1]
fname_tokenizer = sys.argv[2]
dir_out = sys.argv[3]

model = torch.load(fname_model, map_location="cpu")

n_vocab, n_embd = model['model.embed_tokens.weight'].shape
n_layer = 1 + max(int(m.group(1)) for name in model
                  if (m := re.match(r'model\.layers\.([0-9]+)', name)))

# hardcoded:
n_mult = 256
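# Map layer count to attention head count; these match the published LLaMA
# configurations (7B/13B/30B/65B have 32/40/60/80 layers respectively).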
n_head = {32: 32, 40: 40, 60: 52, 80: 64}[n_layer]

tokenizer = SentencePieceProcessor(fname_tokenizer)
assert tokenizer.vocab_size() == n_vocab
fname_out = sys.argv[3]
fout = open(fname_out, "wb")

fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", n_vocab))
fout.write(struct.pack("i", n_embd))
fout.write(struct.pack("i", n_mult))
fout.write(struct.pack("i", n_head))
fout.write(struct.pack("i", n_layer))
fout.write(struct.pack("i", n_embd // n_head)) # rot (obsolete)
fout.write(struct.pack("i", 4))
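# NOTE: the 4 written above is the model-wide ftype field; it presumably
# corresponds to llama.cpp's "mostly Q4_1, some F16" setting, since the
# attention/FFN weights below are written as Q4_1 while embeddings, norms
# and the output head stay f16/f32.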
# This loop unchanged from convert-pth-to-ggml.py:
for i in range(tokenizer.vocab_size()):
    if tokenizer.is_unknown(i):
        # "<unk>" token (translated as ??)
        text = " \u2047 ".encode("utf-8")
        fout.write(struct.pack("i", len(text)))
        fout.write(text)
    elif tokenizer.is_control(i):
        # "<s>"/"</s>" tokens
        fout.write(struct.pack("i", 0))
    elif tokenizer.is_byte(i):
        # "<U+XX>" tokens (which may be invalid UTF-8)
        piece = tokenizer.id_to_piece(i)
        if len(piece) != 6:
            print("Invalid token: " + piece)
            sys.exit(1)
        byte_value = int(piece[3:-1], 16)
        fout.write(struct.pack("i", 1))
        fout.write(struct.pack("B", byte_value))
    else:
        # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
        text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
        fout.write(struct.pack("i", len(text)))
        fout.write(text)
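# Each vocab entry above is just an int32 byte length followed by the raw
# UTF-8 bytes; this early ggml format does not store per-token scores.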
def write_header(shape, dst_name, ftype_cur):
    sname = dst_name.encode('utf-8')
    fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur))
    fout.write(struct.pack("i" * len(shape), *shape[::-1]))
    fout.write(sname)
def convert_non_q4(src_name, dst_name):
    v = model[src_name]
    shape = v.shape
    print("Processing non-Q4 variable: " + src_name + " with shape: ", shape, " and type: ", v.dtype)
    if len(shape) == 1:
        print("  Converting to float32")
        v = v.to(torch.float32)
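    # Per-tensor ftype codes in the ggml header: 0 = f32, 1 = f16
    # (the Q4_1 tensors written by convert_q4 below use ftype 3).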
    ftype_cur = {torch.float16: 1, torch.float32: 0}[v.dtype]

    # header
    write_header(shape, dst_name, ftype_cur)

    # data
    v.numpy().tofile(fout)
def convert_q4(src_name, dst_name, permute=False):
    zeros = model[f"{src_name}.zeros"].numpy()
    scales = model[f"{src_name}.scales"].numpy()
    bias = model[f"{src_name}.bias"].numpy()
    qweight = model[f"{src_name}.qweight"].numpy().T # transpose

    # Q4_1 does not support bias; good thing the bias is always all zeros.
    assert not np.any(bias)

    # Each int32 item is actually 8 int4 items packed together, and it's transposed.
    shape = (qweight.shape[0], qweight.shape[1] * 8)

    print("Processing Q4 variable: " + src_name + " with shape: ", shape)

    # The output format has the int4 weights in groups of 32 rather than 8.
    # It looks like this:
    # For each row:
    #   For each group of 32 columns:
    #     - addend (float32, 4 bytes)
    #     - scale (float32, 4 bytes)
    #     - weights (int4 * 32, 16 bytes)
    # Note that in the input, the scales and addends are shared between all
    # the columns in a row, so we end up wasting quite a bit of memory with
    # repeated scales and addends.
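    # So each group of 32 quantized weights costs 4 + 4 + 16 = 24 bytes,
    # i.e. 6 bits per weight including the repeated scale/addend.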
    addends = -zeros # flip sign

    # Since the output format is mixed between integers and floats, we have
    # to hackily view the floats as int32s just so numpy will let us
    # concatenate them.
    addends_view = addends.view(dtype=np.int32)
    scales_view = scales.view(dtype=np.int32)

    # Split into groups of 4 columns (i.e. 32 columns of quantized data):
    grouped = qweight.reshape([qweight.shape[0], qweight.shape[1] // 4, 4])

    # Repeat addends and scales:
    addends_rep = np.atleast_3d(addends_view).repeat(grouped.shape[1], axis=1)
    scales_rep = np.atleast_3d(scales_view).repeat(grouped.shape[1], axis=1)

    blob = np.concatenate([scales_rep, addends_rep, grouped], axis=2, casting='no')

    if permute:
        # Permute some rows to undo the permutation done by convert_llama_weights_to_hf.py.
        # This can be done after the above conversion because it doesn't affect column order/layout.
        blob = (blob.reshape(n_head, 2, shape[0] // n_head // 2, *blob.shape[1:])
                    .swapaxes(1, 2)
                    .reshape(blob.shape))

    # header
    write_header(shape, dst_name, 3) # ftype = Q4_1

    # data
    blob.tofile(fout)
convert_non_q4("model.embed_tokens.weight", "tok_embeddings.weight")
convert_non_q4("model.norm.weight", "norm.weight")
convert_non_q4("lm_head.weight", "output.weight")

for i in range(n_layer):
    convert_q4(f"model.layers.{i}.self_attn.q_proj", f"layers.{i}.attention.wq.weight", permute=True)
    convert_q4(f"model.layers.{i}.self_attn.k_proj", f"layers.{i}.attention.wk.weight", permute=True)
    convert_q4(f"model.layers.{i}.self_attn.v_proj", f"layers.{i}.attention.wv.weight")
    convert_q4(f"model.layers.{i}.self_attn.o_proj", f"layers.{i}.attention.wo.weight")

    convert_q4(f"model.layers.{i}.mlp.gate_proj", f"layers.{i}.feed_forward.w1.weight")
    convert_q4(f"model.layers.{i}.mlp.down_proj", f"layers.{i}.feed_forward.w2.weight")
    convert_q4(f"model.layers.{i}.mlp.up_proj", f"layers.{i}.feed_forward.w3.weight")

    convert_non_q4(f"model.layers.{i}.input_layernorm.weight", f"layers.{i}.attention_norm.weight")
    convert_non_q4(f"model.layers.{i}.post_attention_layernorm.weight", f"layers.{i}.ffn_norm.weight")

fout.close()

print("Done. Output file: " + fname_out)
print("")