|
|
@@ -6389,8 +6389,8 @@ def parse_args() -> argparse.Namespace:
|
|
|
help="model is executed on big endian machine",
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
- "model", type=Path,
|
|
|
- help="directory containing model file",
|
|
|
+ "model", type=str,
|
|
|
+ help="directory containing model file or huggingface repository ID (if --remote)",
|
|
|
nargs="?",
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
@@ -6493,18 +6493,20 @@ def main() -> None:
|
|
|
else:
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
- dir_model = args.model
|
|
|
-
|
|
|
if args.remote:
|
|
|
+ hf_repo_id = args.model
|
|
|
from huggingface_hub import snapshot_download
|
|
|
local_dir = snapshot_download(
|
|
|
- repo_id=str(dir_model),
|
|
|
+ repo_id=hf_repo_id,
|
|
|
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
|
|
|
dir_model = Path(local_dir)
|
|
|
logger.info(f"Downloaded config and tokenizer to {local_dir}")
|
|
|
+ else:
|
|
|
+ hf_repo_id = None
|
|
|
+ dir_model = Path(args.model)
|
|
|
|
|
|
if not dir_model.is_dir():
|
|
|
- logger.error(f'Error: {args.model} is not a directory')
|
|
|
+ logger.error(f'Error: {dir_model} is not a directory')
|
|
|
sys.exit(1)
|
|
|
|
|
|
ftype_map: dict[str, gguf.LlamaFileType] = {
|
|
|
@@ -6524,9 +6526,9 @@ def main() -> None:
|
|
|
|
|
|
if args.outfile is not None:
|
|
|
fname_out = args.outfile
|
|
|
- elif args.remote:
|
|
|
+ elif hf_repo_id:
|
|
|
# if remote, use the model ID as the output file name
|
|
|
- fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
|
|
|
+ fname_out = Path("./" + hf_repo_id.replace("/", "-") + "-{ftype}.gguf")
|
|
|
else:
|
|
|
fname_out = dir_model
|
|
|
|
|
|
@@ -6555,7 +6557,7 @@ def main() -> None:
|
|
|
split_max_tensors=args.split_max_tensors,
|
|
|
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
|
|
small_first_shard=args.no_tensor_first_split,
|
|
|
- remote_hf_model_id=str(args.model) if args.remote else None)
|
|
|
+ remote_hf_model_id=hf_repo_id)
|
|
|
|
|
|
if args.vocab_only:
|
|
|
logger.info("Exporting model vocab...")
|