@@ -226,6 +226,9 @@ def get_base_tensor_name(lora_tensor_name: str) -> str:
base_name = lora_tensor_name.replace("base_model.model.", "")
base_name = base_name.replace(".lora_A.weight", ".weight")
base_name = base_name.replace(".lora_B.weight", ".weight")
+ # models produced by mergekit-extract-lora have token embeddings in the adapter
+ base_name = base_name.replace(".lora_embedding_A", ".weight")
+ base_name = base_name.replace(".lora_embedding_B", ".weight")
return base_name
 
 
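For reference, the renaming above can be exercised on its own; the helper name map_lora_name and the adapter tensor names below are illustrative (they follow PEFT / mergekit-extract-lora naming conventions but are not taken from a specific adapter):

def map_lora_name(lora_tensor_name: str) -> str:
    # same replacements as the patched get_base_tensor_name()
    base_name = lora_tensor_name.replace("base_model.model.", "")
    base_name = base_name.replace(".lora_A.weight", ".weight")
    base_name = base_name.replace(".lora_B.weight", ".weight")
    # token embeddings extracted by mergekit-extract-lora use the .lora_embedding_A/_B suffix
    base_name = base_name.replace(".lora_embedding_A", ".weight")
    base_name = base_name.replace(".lora_embedding_B", ".weight")
    return base_name

assert map_lora_name("base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight") == "model.layers.0.self_attn.q_proj.weight"
assert map_lora_name("base_model.model.model.embed_tokens.lora_embedding_A") == "model.embed_tokens.weight"
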
@@ -260,6 +263,10 @@ def parse_args() -> argparse.Namespace:
"--base", type=Path,
help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
)
+ parser.add_argument(
+ "--base-model-id", type=str,
+ help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')",
+ )
parser.add_argument(
"lora_path", type=Path,
help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)",
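Assuming this parser belongs to llama.cpp's convert_lora_to_gguf.py (the script name is not visible in this hunk) and reusing the example model ID from the help text, the new flag would be used roughly as:

python convert_lora_to_gguf.py --base-model-id meta-llama/Llama-3.2-1B-Instruct /path/to/lora_adapter
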
@@ -290,6 +297,7 @@ if __name__ == '__main__':
 
dir_base_model: Path | None = args.base
dir_lora: Path = args.lora_path
+ base_model_id: str | None = args.base_model_id
lora_config = dir_lora / "adapter_config.json"
input_model = dir_lora / "adapter_model.safetensors"
 
@@ -313,7 +321,10 @@ if __name__ == '__main__':
lparams: dict[str, Any] = json.load(f)
 
# load base model
- if dir_base_model is None:
+ if base_model_id is not None:
+ logger.info(f"Loading base model from Hugging Face: {base_model_id}")
+ hparams = load_hparams_from_hf(base_model_id)
+ elif dir_base_model is None:
if "base_model_name_or_path" in lparams:
model_id = lparams["base_model_name_or_path"]
logger.info(f"Loading base model from Hugging Face: {model_id}")
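Combined with the surrounding branches, the base-model configuration is now resolved in this order of precedence. The sketch below is a simplified standalone restatement: resolve_base_config is a hypothetical name, the returned tags are for illustration only, and the local-directory and error branches are paraphrased from the rest of the script rather than shown in this hunk:

def resolve_base_config(base_model_id, dir_base_model, lparams):
    if base_model_id is not None:
        # explicit --base-model-id wins and bypasses --base entirely
        return ("hf_hub", base_model_id)
    if dir_base_model is None and "base_model_name_or_path" in lparams:
        # no --base given: fall back to the model ID recorded in adapter_config.json
        return ("hf_hub", lparams["base_model_name_or_path"])
    if dir_base_model is not None:
        # assumed: a local --base directory supplies config.json directly
        return ("local_dir", dir_base_model)
    raise ValueError("cannot determine the base model configuration")
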
@@ -371,11 +382,16 @@ if __name__ == '__main__':
if self.lazy:
tensor = LazyTorchTensor.from_eager(tensor)
base_name = get_base_tensor_name(name)
- is_lora_a = ".lora_A.weight" in name
- is_lora_b = ".lora_B.weight" in name
+ # note: mergekit-extract-lora also adds token embeddings to the adapter
+ is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
+ is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
if not is_lora_a and not is_lora_b:
if ".base_layer.weight" in name:
continue
+ # mergekit-extract-lora adds these layernorms to the adapter; we need to keep them
+ if "_layernorm" in name or ".norm" in name:
+ yield (base_name, tensor)
+ continue
logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
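The classification above can be summarized standalone as follows; classify is a hypothetical helper and the example tensor names are illustrative:

def classify(name: str) -> str:
    if ".lora_A.weight" in name or ".lora_embedding_A" in name:
        return "lora_a"
    if ".lora_B.weight" in name or ".lora_embedding_B" in name:
        return "lora_b"
    if ".base_layer.weight" in name:
        return "skip"
    if "_layernorm" in name or ".norm" in name:
        return "norm"  # kept as a full tensor, not split into a LoRA pair
    return "unexpected"

assert classify("base_model.model.model.embed_tokens.lora_embedding_B") == "lora_b"
assert classify("base_model.model.model.layers.0.input_layernorm.weight") == "norm"
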
@@ -407,9 +423,21 @@ if __name__ == '__main__':
if name == "lm_head.weight" and len(dest) == 0:
raise ValueError("lm_head is present in adapter, but is ignored in base model")
for dest_name, dest_data in dest:
+ # mergekit-extract-lora adds these layernorms to the adapter
+ if "_norm" in dest_name:
+ assert dest_data.dim() == 1
+ yield (dest_name, dest_data)
+ continue
+
+ # otherwise, we must get the lora_A and lora_B tensors
assert isinstance(dest_data, LoraTorchTensor)
lora_a, lora_b = dest_data.get_lora_A_B()
 
+ # note: mergekit-extract-lora flips and transposes A and B
+ # here we only need to transpose token_embd.lora_a, see llm_build_inp_embd()
+ if "token_embd.weight" in dest_name:
+ lora_a = lora_a.T
+
yield (dest_name + ".lora_a", lora_a)
yield (dest_name + ".lora_b", lora_b)
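To close, a standalone sketch of the per-tensor write-out rule from this hunk. A LoRA pair is modeled as a plain (A, B) tuple instead of the script's LoraTorchTensor, emit_tensor is a hypothetical name, and "blk.0.attn_norm.weight" / "token_embd.weight" are GGUF-side names used only as examples:

import torch

def emit_tensor(dest_name, dest_data):
    if "_norm" in dest_name:
        # mergekit-extract-lora adds full layernorm vectors; pass them through unchanged
        assert dest_data.dim() == 1
        yield (dest_name, dest_data)
        return
    lora_a, lora_b = dest_data
    # mergekit-extract-lora flips and transposes A and B; only the token-embedding
    # A matrix needs to be transposed back (see llm_build_inp_embd() in llama.cpp)
    if "token_embd.weight" in dest_name:
        lora_a = lora_a.T
    yield (dest_name + ".lora_a", lora_a)
    yield (dest_name + ".lora_b", lora_b)

print(list(emit_tensor("blk.0.attn_norm.weight", torch.ones(8))))
print([name for name, _ in emit_tensor("token_embd.weight", (torch.zeros(8, 4), torch.zeros(4, 8)))])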