|
|
@@ -3,6 +3,9 @@
|
|
|
import os
|
|
|
import sys
|
|
|
import torch
|
|
|
+import transformers
|
|
|
+import json
|
|
|
+import textwrap
|
|
|
import numpy as np
|
|
|
from pathlib import Path
|
|
|
|
|
|
@@ -243,3 +246,54 @@ def compare_tokens(original, converted, type_suffix="", output_dir="data"):
|
|
|
print(f" ... and {len(mismatches) - num_to_show} more mismatches")
|
|
|
|
|
|
return False
|
|
|
+
|
|
|
+
|
|
|
+def show_version_warning(current_version, model_version):
|
|
|
+ if not model_version:
|
|
|
+ return False
|
|
|
+
|
|
|
+ try:
|
|
|
+ from packaging.version import parse, InvalidVersion
|
|
|
+ try:
|
|
|
+ return parse(current_version) < parse(model_version)
|
|
|
+ except InvalidVersion:
|
|
|
+ return current_version != model_version
|
|
|
+ except ImportError:
|
|
|
+ return current_version != model_version
|
|
|
+
|
|
|
+def get_model_transformers_version(model_path):
|
|
|
+ if not model_path:
|
|
|
+ return None
|
|
|
+
|
|
|
+ config_path = Path(model_path) / "config.json"
|
|
|
+ if not config_path.is_file():
|
|
|
+ return None
|
|
|
+
|
|
|
+ try:
|
|
|
+ with open(config_path, "r", encoding="utf-8") as f:
|
|
|
+ config = json.load(f)
|
|
|
+ return config.get("transformers_version")
|
|
|
+ except (IOError, json.JSONDecodeError) as e:
|
|
|
+ print(f"Warning: Could not read or parse {config_path}: {e}", file=sys.stderr)
|
|
|
+ return None
|
|
|
+
|
|
|
+def exit_with_warning(message, model_path):
|
|
|
+ print(message)
|
|
|
+
|
|
|
+ if model_path and transformers is not None:
|
|
|
+ model_transformers_version = get_model_transformers_version(model_path)
|
|
|
+ transformers_version = transformers.__version__
|
|
|
+ if show_version_warning(transformers_version, model_transformers_version):
|
|
|
+ warning_message = f"""
|
|
|
+ =====================================================================
|
|
|
+ Verification failure might be due to a transformers version mismatch:
|
|
|
+
|
|
|
+ Current transformers version: {transformers_version}
|
|
|
+ Model's required version : {model_transformers_version}
|
|
|
+
|
|
|
+ Consider installing the version specified by the model's config:
|
|
|
+ pip install transformers=={model_transformers_version}
|
|
|
+ =====================================================================
|
|
|
+ """
|
|
|
+ print(textwrap.dedent(warning_message))
|
|
|
+ sys.exit(1)
|