convert_pt_to_hf.py

# convert the https://huggingface.co/novateur/WavTokenizer-large-speech-75token model to HF format
# the goal is to be able to reuse convert_hf_to_gguf.py afterwards to create a GGUF file with the WavTokenizer decoder
#
# TODO: this script is LLM-generated and probably very inefficient and should be rewritten
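#
# usage:
#
#   python convert_pt_to_hf.py path/to/model.pt
#
# the converted files (model.safetensors, index.json, config.json) are written
# next to the input checkpoint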

import torch
import json
import os
import sys
import re

from safetensors.torch import save_file

# default
model_path = './model.pt'

# read from CLI
if len(sys.argv) > 1:
    model_path = sys.argv[1]

# get the directory of the input model; fall back to the current directory so a
# bare filename like 'model.pt' does not produce an empty destination path
path_dst = os.path.dirname(model_path) or '.'

print(f"Loading model from {model_path}")

model = torch.load(model_path, map_location='cpu')

#print(model)
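
# the .pt file is a full training checkpoint (a dict holding 'state_dict',
# 'hyper_parameters', etc.), not a bare nn.Module, so inspect its keys first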

# print all keys
for key in model.keys():
    print(key)
    if key == 'hyper_parameters':
        #print(model[key])
        # dump as pretty-printed json
        print(json.dumps(model[key], indent=4))
    #if key != 'state_dict' and key != 'optimizer_states':
    #    print(model[key])

# Check if the loaded model is a state_dict or a model instance
if isinstance(model, torch.nn.Module):
    state_dict = model.state_dict()
else:
    state_dict = model
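
# note: for this checkpoint torch.load returns a plain dict, so 'state_dict' here
# is the entire checkpoint - the actual weights live under its 'state_dict' entry,
# which is why the keys are filtered by their 'state_dict.' prefix below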

# Print the structure of the state_dict to understand its format
print("State dictionary keys:")
for key in state_dict.keys():
    print(key)

# Ensure the state_dict is flat and contains only torch.Tensor objects
def flatten_state_dict(state_dict, parent_key='', sep='.'):
    items = []
    items_new = []

    for k, v in state_dict.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, torch.Tensor):
            items.append((new_key, v))
        elif isinstance(v, dict):
            # recurse into the first nested dict (for this checkpoint that is the
            # 'state_dict' entry, whose values are all tensors) and return right
            # away - the filter/rename loop below therefore only runs inside that
            # recursive call, where every key carries the 'state_dict.' prefix
            items.extend(flatten_state_dict(v, new_key, sep=sep).items())
            return dict(items)

    size_total_mb = 0

    for key, value in list(items):
        # keep only what we need for inference
        if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \
           not key.startswith('state_dict.backbone.') and \
           not key.startswith('state_dict.head.out'):
            print('Skipping key: ', key)
            continue

        new_key = key

        new_key = new_key.replace('state_dict.', '')
        new_key = new_key.replace('pos_net', 'posnet')

        # check if matches "backbone.posnet.%d.bias" or "backbone.posnet.%d.weight"
        if new_key.startswith("backbone.posnet."):
            match = re.match(r"backbone\.posnet\.(\d+)\.(bias|weight)", new_key)
            if match:
                new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}"
  66. # "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" -> "backbone.embedding.weight"
  67. if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed":
  68. new_key = "backbone.embedding.weight"

        # only the first row of the scale/shift tables is used at inference time
        # ref: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/wav_tokenizer/audio_codec.py#L100
        if new_key.endswith("norm.scale.weight"):
            new_key = new_key.replace("norm.scale.weight", "norm.weight")
            value = value[0]

        if new_key.endswith("norm.shift.weight"):
            new_key = new_key.replace("norm.shift.weight", "norm.bias")
            value = value[0]

        if new_key.endswith("gamma"):
            new_key = new_key.replace("gamma", "gamma.weight")

        # convert from 1D [768] to 2D [768, 1] so that ggml_add can broadcast the bias
        if (new_key.endswith("norm.weight") or new_key.endswith("norm1.weight") or
            new_key.endswith("norm2.weight") or new_key.endswith(".bias")) and \
           (new_key.startswith("backbone.posnet") or new_key.startswith("backbone.embed.bias")):
            value = value.unsqueeze(1)

        if new_key.endswith("dwconv.bias"):
            value = value.unsqueeze(1)

        size_mb = value.element_size() * value.nelement() / (1024 * 1024)
        print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")

        size_total_mb += size_mb

        #print(key, '->', new_key, ': ', value)
        #print(key, '->', new_key)

        items_new.append((new_key, value))

    print(f"Total size: {size_total_mb:8.2f} MB")

    return dict(items_new)

flattened_state_dict = flatten_state_dict(state_dict)

# Convert the model to the safetensors format (save_file expects exactly this
# kind of flat {name: tensor} mapping)
output_path = path_dst + '/model.safetensors'
save_file(flattened_state_dict, output_path)

print(f"Model has been successfully converted and saved to {output_path}")

# Calculate the total size of the .safetensors file
total_size = os.path.getsize(output_path)

# Create the weight map
weight_map = {
    "model.safetensors": ["*"] # Assuming all weights are in one file
}

# Create metadata for the index.json file
metadata = {
    "total_size": total_size,
    "weight_map": weight_map
}

# Save the metadata to index.json
index_path = path_dst + '/index.json'
with open(index_path, 'w') as f:
    json.dump(metadata, f, indent=4)

print(f"Metadata has been saved to {index_path}")

config = {
    "architectures": [
        "WavTokenizerDec"
    ],
    "hidden_size": 1282,
    "n_embd_features": 512,
    "n_ff": 2304,
    "vocab_size": 4096,
    "n_head": 1,
    "layer_norm_epsilon": 1e-6,
    "group_norm_epsilon": 1e-6,
    "group_norm_groups": 32,
    "max_position_embeddings": 8192, # ?
    "n_layer": 12,
    "posnet": {
        "n_embd": 768,
        "n_layer": 6
    },
    "convnext": {
        "n_embd": 768,
        "n_layer": 12
    },
}

with open(path_dst + '/config.json', 'w') as f:
    json.dump(config, f, indent=4)

print(f"Config has been saved to {path_dst + '/config.json'}")