@@ -93,13 +93,15 @@ class ModelBase:
     # Mistral format specifics
     is_mistral_format: bool = False
     disable_mistral_community_chat_template: bool = False
+    sentence_transformers_dense_modules: bool = False
 
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
-                 disable_mistral_community_chat_template: bool = False):
+                 disable_mistral_community_chat_template: bool = False,
+                 sentence_transformers_dense_modules: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -114,6 +116,7 @@ class ModelBase:
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
+        self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
         if remote_hf_model_id is not None:
             self.is_safetensors = True
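Note: the flag defaults to False at both the class level and the constructor, so existing conversion paths are unaffected; only the EmbeddingGemma subclass below consults it. A minimal sketch of opting in programmatically (hypothetical local path; subclasses forward the keyword through **kwargs):

    model = EmbeddingGemma(Path("./embeddinggemma-300m"), gguf.LlamaFileType.MOSTLY_F16,
                           Path("out.gguf"), sentence_transformers_dense_modules=True)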
@@ -5269,6 +5272,53 @@ class Gemma3Model(TextModel):
 @ModelBase.register("Gemma3TextModel")
 class EmbeddingGemma(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+    module_paths = []
+    dense_features_dims = {}
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.sentence_transformers_dense_modules:
+            # read modules.json to determine whether the model has Dense layers
+            modules_file = self.dir_model / "modules.json"
+            if modules_file.is_file():
+                with open(modules_file, encoding="utf-8") as modules_json_file:
+                    mods = json.load(modules_json_file)
+                for mod in mods:
+                    if mod["type"] == "sentence_transformers.models.Dense":
+                        mod_path = mod["path"]
+                        # check that a model.safetensors file exists for the Dense layer
+                        model_tensors_file = self.dir_model / mod_path / "model.safetensors"
+                        if model_tensors_file.is_file():
+                            self.module_paths.append(mod_path)
+                            # read the Dense layer's config.json for its in/out features
+                            mod_conf_file = self.dir_model / mod_path / "config.json"
+                            if mod_conf_file.is_file():
+                                with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
+                                    mod_conf = json.load(mod_conf_json_file)
+                                # the dense_2_feat_out and dense_3_feat_in hparams are required when loading the model's dense weights
+                                prefix = self._get_dense_prefix(mod_path)
+                                if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None:
+                                    self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        from safetensors.torch import load_file
+        module_paths = list(self.module_paths)
+        for module_path in module_paths:
+            tensors_file = self.dir_model / module_path / "model.safetensors"
+            local_tensors = load_file(tensors_file)
+            tensor_name = self._get_dense_prefix(module_path)
+            for name, local_tensor in local_tensors.items():
+                if not name.endswith(".weight"):
+                    continue
+                orig_name = name.replace("linear", tensor_name)
+                name = self.map_tensor_name(orig_name)
+                yield name, local_tensor.clone()
+
+    @staticmethod
+    def _get_dense_prefix(module_path: str) -> str:
+        """Get the tensor name prefix for the Dense layer from its module path."""
+        tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
+        return tensor_name
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
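For orientation, __init__ above walks the standard sentence-transformers module layout. A sketch of what it expects on disk for google/embeddinggemma-300m (entries abridged, feature sizes assumed):

    # modules.json
    #   [..., {"idx": 2, "name": "2", "path": "2_Dense",
    #          "type": "sentence_transformers.models.Dense"},
    #         {"idx": 3, "name": "3", "path": "3_Dense",
    #          "type": "sentence_transformers.models.Dense"}]
    # 2_Dense/config.json       -> {"in_features": 768, "out_features": 3072}
    # 2_Dense/model.safetensors -> {"linear.weight": <tensor>}

generate_extra_tensors then renames each module's "linear.weight" to the "dense_2"/"dense_3" prefix returned by _get_dense_prefix before passing it through map_tensor_name. Two caveats: module_paths and dense_features_dims are mutable class attributes, so state accumulates if several models are converted in one process, and _get_dense_prefix assumes exactly the 2_Dense/3_Dense directory names.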
@@ -5285,6 +5335,10 @@ class EmbeddingGemma(Gemma3Model):
             logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
                         f"instead of {self.hparams['sliding_window']}")
             self.gguf_writer.add_sliding_window(orig_sliding_window)
+        if self.sentence_transformers_dense_modules:
+            for dense, dims in self.dense_features_dims.items():
+                logger.info(f"Setting dense layer {dense} in/out features to {dims}")
+                self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
 
         self._try_set_pooling_type()
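Note: add_dense_features_dims is the matching writer-side addition in gguf-py. The assumption is that it emits one in/out pair of KV fields per dense prefix, which the loader reads back to shape the dense_2/dense_3 projections; illustratively (key names assumed, not taken from the source):

    # gemma-embedding.dense_2_feat_in  = 768
    # gemma-embedding.dense_2_feat_out = 3072
    # gemma-embedding.dense_3_feat_in  = 3072
    # gemma-embedding.dense_3_feat_out = 768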
@@ -9335,6 +9389,13 @@ def parse_args() -> argparse.Namespace:
         )
     )
+    parser.add_argument(
+        "--sentence-transformers-dense-modules", action="store_true",
+        help=("Whether to include sentence-transformers dense modules. "
+              "This can be used for sentence-transformers models such as google/embeddinggemma-300m. "
+              "By default these modules are not included.")
+    )
+
     args = parser.parse_args()
 
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
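With the flag registered, a conversion that keeps the dense heads would look roughly like this (illustrative invocation; --remote and --outfile are the script's existing options):

    python convert_hf_to_gguf.py google/embeddinggemma-300m --remote \
        --sentence-transformers-dense-modules --outfile embeddinggemma.gguf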
@@ -9397,9 +9458,13 @@ def main() -> None:
     if args.remote:
         hf_repo_id = args.model
         from huggingface_hub import snapshot_download
+        allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
+        if args.sentence_transformers_dense_modules:
+            # include the sentence-transformers dense modules' safetensors files
+            allowed_patterns.append("*.safetensors")
         local_dir = snapshot_download(
             repo_id=hf_repo_id,
-            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
+            allow_patterns=allowed_patterns)
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")
     else:
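One side effect worth noting: "*.safetensors" also matches the top-level model weights, so a --remote run now downloads the full checkpoint even though the main tensors are otherwise streamed lazily via remote_hf_model_id. If that matters, a narrower pattern would fetch only the dense heads (sketch, assuming the usual sentence-transformers directory names):

    allowed_patterns.append("*_Dense/model.safetensors")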
@@ -9467,7 +9532,8 @@ def main() -> None:
             split_max_tensors=args.split_max_tensors,
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
             small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+            sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
         )
 
         if args.vocab_only: