# convert-pth-to-ggml.py
# Convert a LLaMA model checkpoint to a ggml compatible file
#
# Load the model using Torch
# Iterate over all variables and write them to a binary file.
# Checkpoints split into several consolidated.XX.pth parts produce one
# output file per part.
#
# For each variable, write the following:
#   - Number of dimensions (int)
#   - Name length (int)
#   - Data type (int): 0 -> float32, 1 -> float16
#   - Dimensions (int[n_dims]), innermost (last) dimension first
#   - Name (char[name_length])
#   - Data (the tensor elements, in the type given by the data type field)
#
# By default, the bigger matrices are stored as 16-bit floats.
# This can be disabled by passing ftype 0 (float32) as the second CLI argument.
# One-dimensional tensors are always stored as 32-bit floats.
#
# At the start of the ggml file we write the model parameters
# and vocabulary.
#
import json
import os
import struct
import sys

import numpy as np
import torch
from sentencepiece import SentencePieceProcessor
  25. if len(sys.argv) < 3:
  26. print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
  27. print(" ftype == 0 -> float32")
  28. print(" ftype == 1 -> float16")
  29. sys.exit(1)
  30. # output in the same directory as the model
  31. dir_model = sys.argv[1]
  32. fname_hparams = sys.argv[1] + "/params.json"
  33. fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
  34. def get_n_parts(dim):
  35. if dim == 4096:
  36. return 1
  37. elif dim == 5120:
  38. return 2
  39. elif dim == 6656:
  40. return 4
  41. elif dim == 8192:
  42. return 8
  43. else:
  44. print("Invalid dim: " + str(dim))
  45. sys.exit(1)
  46. # possible data types
  47. # ftype == 0 -> float32
  48. # ftype == 1 -> float16
  49. #
  50. # map from ftype to string
  51. ftype_str = ["f32", "f16"]
  52. ftype = 1
  53. if len(sys.argv) > 2:
  54. ftype = int(sys.argv[2])
  55. if ftype < 0 or ftype > 1:
  56. print("Invalid ftype: " + str(ftype))
  57. sys.exit(1)
  58. fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
  59. with open(fname_hparams, "r") as f:
  60. hparams = json.load(f)
  61. tokenizer = SentencePieceProcessor(fname_tokenizer)
  62. hparams.update({"vocab_size": tokenizer.vocab_size()})
  63. n_parts = get_n_parts(hparams["dim"])
  64. print(hparams)
  65. print('n_parts = ', n_parts)
  66. for p in range(n_parts):
  67. print('Processing part ', p)
  68. #fname_model = sys.argv[1] + "/consolidated.00.pth"
  69. fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
  70. fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
  71. if (p > 0):
  72. fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
  73. model = torch.load(fname_model, map_location="cpu")
  74. fout = open(fname_out, "wb")
  75. fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
  76. fout.write(struct.pack("i", hparams["vocab_size"]))
  77. fout.write(struct.pack("i", hparams["dim"]))
  78. fout.write(struct.pack("i", hparams["multiple_of"]))
  79. fout.write(struct.pack("i", hparams["n_heads"]))
  80. fout.write(struct.pack("i", hparams["n_layers"]))
  81. fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
  82. fout.write(struct.pack("i", ftype))
  83. # Is this correct??
  84. for i in range(32000):
  85. # TODO: this is probably wrong - not sure how this tokenizer works
  86. text = tokenizer.decode([29889, i]).encode('utf-8')
  87. # remove the first byte (it's always '.')
  88. text = text[1:]
  89. fout.write(struct.pack("i", len(text)))
  90. fout.write(text)
  91. for k, v in model.items():
  92. name = k
  93. shape = v.shape
  94. # skip layers.X.attention.inner_attention.rope.freqs
  95. if name[-5:] == "freqs":
  96. continue
  97. print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
  98. #data = tf.train.load_variable(dir_model, name).squeeze()
  99. data = v.numpy().squeeze()
  100. n_dims = len(data.shape);
  101. # for efficiency - transpose some matrices
  102. # "model/h.*/attn/c_attn/w"
  103. # "model/h.*/attn/c_proj/w"
  104. # "model/h.*/mlp/c_fc/w"
  105. # "model/h.*/mlp/c_proj/w"
  106. #if name[-14:] == "/attn/c_attn/w" or \
  107. # name[-14:] == "/attn/c_proj/w" or \
  108. # name[-11:] == "/mlp/c_fc/w" or \
  109. # name[-13:] == "/mlp/c_proj/w":
  110. # print(" Transposing")
  111. # data = data.transpose()
  112. dshape = data.shape
  113. # default type is fp16
  114. ftype_cur = 1
  115. if ftype == 0 or n_dims == 1:
  116. print(" Converting to float32")
  117. data = data.astype(np.float32)
  118. ftype_cur = 0
  119. # header
  120. sname = name.encode('utf-8')
  121. fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
  122. for i in range(n_dims):
  123. fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
  124. fout.write(sname);
  125. # data
  126. data.tofile(fout)
  127. # I hope this deallocates the memory ..
  128. model = None
  129. fout.close()
  130. print("Done. Output file: " + fname_out + ", (part ", p, ")")
  131. print("")