  1. """
  2. Implements the AWQ for llama.cpp use cases.
  3. Original paper: https://arxiv.org/abs/2306.00978
  4. This code is based on versions of the AWQ implementation found in the following repositories:
  5. * https://github.com/mit-han-lab/llm-awq
  6. * https://github.com/casper-hansen/AutoAWQ
  7. """
  8. import os
  9. import torch
  10. import torch.nn as nn
  11. from transformers import AutoModelForCausalLM, AutoConfig
  12. from transformers.models.bloom.modeling_bloom import BloomGelu
  13. from transformers.models.llama.modeling_llama import LlamaRMSNorm
  14. from transformers.activations import GELUActivation


class ScaledActivation(nn.Module):
    """
    ScaledActivation module wraps an existing activation function and applies a
    scale factor to its output.

    Args:
        module (nn.Module): The activation function to be scaled.
        scales (torch.Tensor): A tensor of size (num_features,) containing the initial
            scale factors for each feature.

    Returns:
        torch.Tensor: The scaled output of the activation function.
    """

    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)


def set_op_by_name(layer, name, new_module):
    """
    Replaces the submodule with the given (dot-separated) name by a new module.

    Args:
        layer (nn.Module): The layer in which to replace the submodule.
        name (str): The path to the submodule to be replaced, using dot notation
            to access nested modules.
        new_module (nn.Module): The new module to replace the existing one.
    """
    levels = name.split(".")
    if len(levels) > 1:
        mod_ = layer
        for l_idx in range(len(levels) - 1):
            if levels[l_idx].isdigit():
                mod_ = mod_[int(levels[l_idx])]
            else:
                mod_ = getattr(mod_, levels[l_idx])
        setattr(mod_, levels[-1], new_module)
    else:
        setattr(layer, name, new_module)


def get_op_by_name(module, op_name):
    """
    Retrieves a submodule within a given layer based on its name.

    Args:
        module (nn.Module): The layer containing the submodule to find.
        op_name (str): The name of the submodule.

    Returns:
        nn.Module: The requested submodule found within the given layer.

    Raises:
        ValueError: If the specified submodule cannot be found within the layer.
    """
    for name, m in module.named_modules():
        if name == op_name:
            return m
    raise ValueError(f"Cannot find op {op_name} in module {module}")
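
# Example (illustration only, not part of the original module): the dotted names used by
# get_op_by_name/set_op_by_name follow torch.nn.Module.named_modules() naming, so numeric
# path components index into containers such as nn.Sequential or nn.ModuleList. The tiny
# model below is hypothetical:
#
#     model = nn.Sequential(nn.Linear(4, 4), nn.GELU())
#     act = get_op_by_name(model, "1")                                  # the nn.GELU module
#     set_op_by_name(model, "1", ScaledActivation(act, torch.ones(4)))  # swap it in place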


@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    """
    Scales the weights of a LayerNorm and a list of fully-connected layers proportionally.

    Args:
        ln (nn.LayerNorm): The LayerNorm module to be scaled.
        fcs (List[nn.Linear]): A list of fully-connected layers to be scaled.
        scales (torch.Tensor): A 1D tensor of size (num_features,).
    """
    if not isinstance(fcs, list):
        fcs = [fcs]
    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if hasattr(ln, "bias") and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0
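
# Illustration (added here, not in the original file; shapes are arbitrary): dividing the
# norm weight/bias by per-feature scales while multiplying the input columns of the
# following linear layers by the same scales leaves the composed mapping unchanged:
#
#     ln, fc = nn.LayerNorm(8), nn.Linear(8, 4)
#     x = torch.randn(2, 8)
#     y_ref = fc(ln(x))
#     scale_ln_fcs(ln, fc, torch.rand(8) + 0.5)     # a single fc is accepted as well
#     torch.testing.assert_close(fc(ln(x)), y_ref)  # equal up to floating-point error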


@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    """
    Scales the weights of two fully-connected layers in a specific pattern.

    Args:
        fc1 (nn.Linear): The first fully-connected layer to be scaled.
        fc2 (nn.Linear): The second fully-connected layer to be scaled.
        scales (torch.Tensor): A 1D tensor of size (num_features,).
    """
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)

    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0
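
# Note (added commentary): slicing fc1.weight with [-scales.size(0):] appears intended to
# handle the case where fc1 is a fused projection whose last output block feeds fc2; those
# rows are scaled down while fc2's matching input columns are scaled up, so fc2(fc1(x))
# stays numerically unchanged for the affected features.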


@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    """
    Scales the weights of the fully-connected layer that follows a GELU activation.

    Args:
        gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The GELU activation module preceding `fc`.
        fc (nn.Linear): The fully-connected layer to be scaled.
        scales (torch.Tensor): A 1D tensor of size (num_features,).

    Raises:
        AssertionError: If the `gelu` module is not of type `nn.GELU`, `BloomGelu`, or `GELUActivation`,
            or if the `fc` module is not of type `nn.Linear`.
    """
    assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation))
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0
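
# Note (added commentary): the GELU module itself has no weights, so apply_scale() below
# compensates by wrapping it in ScaledActivation (which divides the activation output by
# `scales`), while scale_gelu_fc() multiplies the next layer's input columns by the same
# `scales`; the end-to-end computation is therefore preserved.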


def apply_scale(module, scales_list, input_feat_dict=None):
    """
    Applies different scaling strategies to layers based on their type and hierarchy within a given module.

    Args:
        module (nn.Module): The module containing the layers to be scaled.
        scales_list (List[Tuple[str, List[str], torch.Tensor]]): A list of tuples containing:
            * prev_op_name (str): The name (relative to the root of `module`) of the operation
                that precedes the layers to be scaled.
            * layer_names (List[str]): Names (relative to the root of `module`) of the layers to be scaled.
            * scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
        input_feat_dict (Optional[Dict[str, torch.Tensor]]): A dictionary mapping layer names to their corresponding
            input features (optional).
    """
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        scales = scales.cuda()  # Tensor.cuda() is not in-place, so rebind the local name

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)) or "rmsnorm" in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)
        elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        scales = scales.cpu()
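
# Example (hypothetical names, for illustration only): a scales_list entry for a
# LLaMA-style block might look like
#
#     ("model.layers.0.post_attention_layernorm",
#      ["model.layers.0.mlp.gate_proj", "model.layers.0.mlp.up_proj"],
#      torch.rand(hidden_size))
#
# i.e. an RMSNorm followed by the two MLP input projections it feeds.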


@torch.no_grad()
def apply_clip(module, clip_list):
    """
    Applies element-wise clipping to the weight of a specific layer within a given module.

    Args:
        module (nn.Module): The module containing the layer to be clipped.
        clip_list (List[Tuple[str, torch.Tensor]]): A list of tuples containing:
            * name (str): The name of the layer to be clipped, relative to the root of the module.
            * max_val (torch.Tensor): A 1D or 2D tensor defining the upper bound for each element of the layer's weight.
    """
    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()
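
# Note (added commentary): the weight is temporarily viewed with the same leading
# dimensions as max_val (per output channel, per quantization group), so the clamp
# broadcasts one clipping bound over each group before the original shape is restored.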


def add_scale_weights(model_path, scale_path, tmp_path):
    """
    Adds pre-computed Activation-aware Weight Quantization (AWQ) results to a model,
    including scaling factors and clipping bounds.

    Args:
        model_path (str): Path to the pre-trained model to be equipped with AWQ.
        scale_path (str): Path to the AWQ scale factors (.pt file).
        tmp_path (str): Path to the temporary directory where the equipped model will be saved.
    """
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, trust_remote_code=True
    )
    model.eval()

    awq_results = torch.load(str(scale_path), map_location="cpu")
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

    model.save_pretrained(str(tmp_path))
    os.system(f"cp {str(model_path)}/tokenizer* {str(tmp_path)}")
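

# Minimal usage sketch (not part of the original module; the CLI flags below are assumptions).
# It simply wires add_scale_weights() to a command line, e.g.:
#
#     python apply_awq.py --model-path <hf-model-dir> --scale-path awq_results.pt --tmp-path ./awq-model
#
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Apply pre-computed AWQ scales and clips to a model")
    parser.add_argument("--model-path", required=True, help="pre-trained model directory or hub id")
    parser.add_argument("--scale-path", required=True, help="path to the AWQ results file (.pt)")
    parser.add_argument("--tmp-path", required=True, help="output directory for the equipped model")
    args = parser.parse_args()

    add_scale_weights(args.model_path, args.scale_path, args.tmp_path)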