@@ -16,13 +16,14 @@ import re
 import signal
 import struct
 import sys
+import textwrap
 import time
 import zipfile
-from abc import ABCMeta, abstractmethod
+from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable

 import numpy as np
 from sentencepiece import SentencePieceProcessor
@@ -43,6 +44,9 @@ ARCH = gguf.MODEL_ARCH.LLAMA

 DEFAULT_CONCURRENCY = 8

+ADDED_TOKENS_FILE = 'added_tokens.json'
+FAST_TOKENIZER_FILE = 'tokenizer.json'
+
 #
 # data types
 #
@@ -188,8 +192,10 @@ class Params:
         n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)

         if n_layer < 1:
-            raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
-                            "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+            msg = """\
+                failed to guess 'n_layer'. This model is unknown or unsupported.
+                Suggestion: provide 'config.json' of the model in the same directory containing model files."""
+            raise KeyError(textwrap.dedent(msg))

         n_head = n_embd // 128 # guessed
         n_mult = 256 # guessed
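
Aside (not part of the patch): a minimal sketch of the textwrap.dedent pattern the new error paths use; the message text below is illustrative only.

import textwrap

# The backslash after the opening quotes suppresses a leading blank line;
# dedent() then strips the common leading whitespace from every line.
msg = """\
    failed to guess 'n_layer'. This model is unknown or unsupported.
    Suggestion: provide 'config.json' of the model in the same directory containing model files."""
print(textwrap.dedent(msg))  # both lines print flush against the left margin
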
@@ -211,7 +217,8 @@ class Params:

     @staticmethod
     def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
-        config = json.load(open(config_path))
+        with open(config_path) as f:
+            config = json.load(f)

         rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
         rope_scaling = config.get("rope_scaling")
@@ -233,8 +240,10 @@ class Params:
         elif "max_position_embeddings" in config:
             n_ctx = config["max_position_embeddings"]
         else:
-            raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
-                            "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+            msg = """\
+                failed to guess 'n_ctx'. This model is unknown or unsupported.
+                Suggestion: provide 'config.json' of the model in the same directory containing model files."""
+            raise KeyError(textwrap.dedent(msg))

         n_experts = None
         n_experts_used = None
@@ -265,7 +274,8 @@ class Params:
     # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
     @staticmethod
     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
-        config = json.load(open(config_path))
+        with open(config_path) as f:
+            config = json.load(f)

         n_experts = None
         n_experts_used = None
@@ -331,47 +341,86 @@ class Params:
 # vocab
 #

-class BpeVocab:
+@runtime_checkable
+class BaseVocab(Protocol):
+    tokenizer_model: ClassVar[str]
+    name: ClassVar[str]
+
+
+class NoVocab(BaseVocab):
+    tokenizer_model = "no_vocab"
+    name = "no_vocab"
+
+    def __repr__(self) -> str:
+        return "<NoVocab for a model without integrated vocabulary>"
+
+
+@runtime_checkable
+class Vocab(BaseVocab, Protocol):
+    vocab_size: int
+    added_tokens_dict: dict[str, int]
+    added_tokens_list: list[str]
+    fname_tokenizer: Path
+
+    def __init__(self, base_path: Path): ...
+    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
+
+
+class BpeVocab(Vocab):
     tokenizer_model = "gpt2"
     name = "bpe"

-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
-        self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
-        if isinstance(self.bpe_tokenizer.get('model'), dict):
-            self.vocab = self.bpe_tokenizer["model"]["vocab"]
-        else:
-            self.vocab = self.bpe_tokenizer
-        added_tokens: dict[str, int]
-        if fname_added_tokens is not None:
-            # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
-            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
+    def __init__(self, base_path: Path):
+        added_tokens: dict[str, int] = {}
+
+        if (fname_tokenizer := base_path / 'vocab.json').exists():
+            # "slow" tokenizer
+            with open(fname_tokenizer, encoding="utf-8") as f:
+                self.vocab = json.load(f)
+
+            try:
+                # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
+                with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
+                    added_tokens = json.load(f)
+            except FileNotFoundError:
+                pass
         else:
-            # Fall back to trying to find the added tokens in tokenizer.json
-            tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
-            if not tokenizer_json_file.is_file():
-                added_tokens = {}
-            else:
-                tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
-                added_tokens = dict(
-                    (item['content'], item['id'])
-                    for item in tokenizer_json.get('added_tokens', [])
-                    # Added tokens here can be duplicates of the main vocabulary.
-                    if item['content'] not in self.bpe_tokenizer)
-
-        vocab_size: int = len(self.vocab)
-        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
-        actual_ids = sorted(added_tokens.values())
+            # "fast" tokenizer
+            fname_tokenizer = base_path / FAST_TOKENIZER_FILE
+
+            # if this fails, FileNotFoundError propagates to caller
+            with open(fname_tokenizer, encoding="utf-8") as f:
+                tokenizer_json = json.load(f)
+
+            tokenizer_model: dict[str, Any] = tokenizer_json['model']
+            if (
+                tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
+                or tokenizer_json['decoder']['type'] != 'ByteLevel'
+            ):
+                raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')
+
+            self.vocab = tokenizer_model["vocab"]
+
+            if (added := tokenizer_json.get('added_tokens')) is not None:
+                # Added tokens here can be duplicates of the main vocabulary.
+                added_tokens = {item['content']: item['id']
+                                for item in added
+                                if item['content'] not in self.vocab}
+
+        vocab_size = len(self.vocab)
+        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
+        actual_ids = sorted(added_tokens.values())
         if expected_ids != actual_ids:
             expected_end_id = vocab_size + len(actual_ids) - 1
-            raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
+            raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
+                             f"{vocab_size} - {expected_end_id}; got {actual_ids}")

         items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
         self.added_tokens_dict = added_tokens
         self.added_tokens_list = [text for (text, idx) in items]
-        self.vocab_size_base: int = vocab_size
-        self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
+        self.vocab_size_base = vocab_size
+        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens

     def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}
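
Aside (not part of the patch): a standalone sketch of the runtime_checkable Protocol mechanics that BaseVocab and Vocab rely on; Named and Spm are made-up names used only for illustration.

from typing import ClassVar, Protocol, runtime_checkable

@runtime_checkable
class Named(Protocol):
    name: ClassVar[str]

class Spm:  # note: no explicit inheritance from Named
    name = "spm"

# isinstance() against a runtime_checkable Protocol is a structural check:
# it only verifies that the attribute exists on the instance.
print(isinstance(Spm(), Named))  # True
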
@@ -392,19 +441,25 @@ class BpeVocab:
         return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


-class SentencePieceVocab:
+class SentencePieceVocab(Vocab):
     tokenizer_model = "llama"
     name = "spm"

-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
-        added_tokens: dict[str, int]
-        if fname_added_tokens is not None:
-            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
-        else:
-            added_tokens = {}
+    def __init__(self, base_path: Path):
+        added_tokens: dict[str, int] = {}
+        if (fname_tokenizer := base_path / 'tokenizer.model').exists():
+            # normal location
+            try:
+                with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
+                    added_tokens = json.load(f)
+            except FileNotFoundError:
+                pass
+        elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
+            # not found in alternate location either
+            raise FileNotFoundError('Cannot find tokenizer.model')

-        vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
+        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+        vocab_size = self.sentencepiece_tokenizer.vocab_size()

         new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
         expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
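
Aside (not part of the patch): the assignment-expression fallback the new constructors use to probe more than one location, shown as a hypothetical helper; find_tokenizer is not a function in this patch.

from pathlib import Path

def find_tokenizer(base_path: Path) -> Path:
    # := binds the candidate path and tests it in the same expression
    if (p := base_path / 'tokenizer.model').exists():
        return p  # normal location
    if (p := base_path.parent / 'tokenizer.model').exists():
        return p  # alternate location, one directory up
    raise FileNotFoundError('Cannot find tokenizer.model')
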
@@ -414,18 +469,17 @@ class SentencePieceVocab:
             raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")

         # Token pieces that were added to the base vocabulary.
-        self.added_tokens_dict = added_tokens
+        self.added_tokens_dict = added_tokens
         self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
         self.vocab_size_base = vocab_size
         self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens

     def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         tokenizer = self.sentencepiece_tokenizer
         for i in range(tokenizer.vocab_size()):
             piece = tokenizer.id_to_piece(i)
-            text: bytes = piece.encode("utf-8")
+            text = piece.encode("utf-8")
             score: float = tokenizer.get_score(i)

             toktype = gguf.TokenType.NORMAL
@@ -458,27 +512,42 @@ class SentencePieceVocab:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


-class HfVocab:
+class LlamaHfVocab(Vocab):
     tokenizer_model = "llama"
     name = "hfft"

-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None:
+    def __init__(self, base_path: Path, ignore_nonllama: bool = False):
+        fname_tokenizer = base_path / FAST_TOKENIZER_FILE
+        # if this fails, FileNotFoundError propagates to caller
+        with open(fname_tokenizer, encoding='utf-8') as f:
+            tokenizer_json = json.load(f)
+
+        # pre-check so we know if we need transformers
+        tokenizer_model: dict[str, Any] = tokenizer_json['model']
+        if ignore_nonllama:
+            pass  # workaround incorrect use of this class for WordPiece
+        elif (
+            tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
+            or tokenizer_json['decoder']['type'] != 'Sequence'
+        ):
+            raise FileNotFoundError('Cannot find Llama BPE tokenizer')
+
         try:
             from transformers import AutoTokenizer
         except ImportError as e:
             raise ImportError(
-                "To use HfVocab, please install the `transformers` package. "
+                "To use LlamaHfVocab, please install the `transformers` package. "
                 "You can install it with `pip install transformers`."
             ) from e

-        print("fname_tokenizer:", fname_tokenizer)
         # Allow the tokenizer to default to slow or fast versions.
         # Explicitly set tokenizer to use local paths.
         self.tokenizer = AutoTokenizer.from_pretrained(
-            fname_tokenizer,
-            cache_dir=fname_tokenizer,
+            base_path,
+            cache_dir=base_path,
             local_files_only=True,
         )
+        assert self.tokenizer.is_fast  # assume tokenizer.json is used

         # Initialize lists and dictionaries for added tokens
         self.added_tokens_list = []
@@ -506,8 +575,7 @@ class HfVocab:
         self.vocab_size_base = self.tokenizer.vocab_size
         self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)

-        self.fname_tokenizer = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens
+        self.fname_tokenizer = fname_tokenizer

     def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         reverse_vocab = {
@@ -559,18 +627,7 @@ class HfVocab:
         yield from self.added_tokens()

     def __repr__(self) -> str:
-        return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
-
-
-class NoVocab:
-    tokenizer_model = "no_vocab"
-    name = "no_vocab"
-
-    def __repr__(self) -> str:
-        return "<NoVocab for a model without integrated vocabulary>"
-
-
-Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab"
+        return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


 #
@@ -588,7 +645,7 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
             .reshape(weights.shape))


-class Tensor(metaclass=ABCMeta):
+class Tensor(ABC):
     data_type: DataType

     @abstractmethod
@@ -610,7 +667,7 @@ def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:


 class UnquantizedTensor(Tensor):
-    def __init__(self, ndarray: NDArray) -> None:
+    def __init__(self, ndarray: NDArray):
         assert isinstance(ndarray, np.ndarray)
         self.ndarray = ndarray
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
@@ -689,7 +746,7 @@ class ModelPlus:
     model: LazyModel
     paths: list[Path]  # Where this was read from.
     format: Literal['ggml', 'torch', 'safetensors', 'none']
-    vocab: Vocab | None  # For GGML models (which have vocab built in), the vocab.
+    vocab: BaseVocab | None  # For GGML models (which have vocab built in), the vocab.


 def merge_sharded(models: list[LazyModel]) -> LazyModel:
@@ -698,7 +755,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:
     names = {name: None for model in models for name in model}

     def convert(name: str) -> LazyTensor:
-        lazy_tensors: list[LazyTensor] = [model[name] for model in models]
+        lazy_tensors = [model[name] for model in models]
         if len(lazy_tensors) == 1:
             # only one file; don't go through this procedure since there might
             # be quantized tensors
@@ -719,7 +776,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:

         def load() -> UnquantizedTensor:
             ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors]
-            concatenated: NDArray = np.concatenate(ndarrays, axis=axis)
+            concatenated = np.concatenate(ndarrays, axis=axis)
             return UnquantizedTensor(concatenated)
         description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]'
         return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description)
@@ -807,10 +864,10 @@ class LazyUnpickler(pickle.Unpickler):

         def load(offset: int, elm_count: int) -> NDArray:
             dtype = data_type.dtype
-            fp = self.zip_file.open(info)
-            fp.seek(offset * dtype.itemsize)
-            size = elm_count * dtype.itemsize
-            data = fp.read(size)
+            with self.zip_file.open(info) as fp:
+                fp.seek(offset * dtype.itemsize)
+                size = elm_count * dtype.itemsize
+                data = fp.read(size)
             assert len(data) == size
             return np.frombuffer(data, dtype)
         description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
@@ -831,7 +888,7 @@ class LazyUnpickler(pickle.Unpickler):
     def rebuild_from_type_v2(func, new_type, args, state):
         return func(*args)

-    CLASSES: dict[tuple[str, str], Any] = {
+    CLASSES = {
         # getattr used here as a workaround for mypy not being smart enough to determine
         # the staticmethods have a __func__ attribute.
         ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
@@ -890,7 +947,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
 def must_read(fp: IO[bytes], length: int) -> bytes:
     ret = fp.read(length)
     if len(ret) < length:
-        raise Exception("unexpectedly reached end of file")
+        raise EOFError("unexpectedly reached end of file")
     return ret


@@ -948,13 +1005,14 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
             yield result


-def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
+def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None:
     # Handle special case where the model's vocab size is not set
     if params.n_vocab == -1:
         raise ValueError(
-            f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
+            "The model's vocab size is set to -1 in params.json. Please update it manually."
+            + (f" Maybe {vocab.vocab_size}?" if isinstance(vocab, Vocab) else ""),
         )
-    if isinstance(vocab, NoVocab):
+    if not isinstance(vocab, Vocab):
         return  # model has no vocab

     # Check for a vocab size mismatch
@@ -979,11 +1037,11 @@ def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> N
     if vocab.vocab_size < params.n_vocab:
         msg += " Add the --pad-vocab option and try again."

-    raise Exception(msg)
+    raise ValueError(msg)


 class OutputFile:
-    def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
+    def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)

     def add_meta_arch(self, params: Params) -> None:
@@ -1034,8 +1092,6 @@ class OutputFile:
         self.gguf.add_file_type(params.ftype)

     def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
-        assert not isinstance(vocab, NoVocab)
-
         tokens = []
         scores = []
         toktypes = []
@@ -1135,7 +1191,7 @@ class OutputFile:

     @staticmethod
     def write_all(
-        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
+        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
     ) -> None:
@@ -1145,11 +1201,11 @@ class OutputFile:

         # meta data
         of.add_meta_arch(params)
-        if isinstance(vocab, NoVocab):
-            of.gguf.add_tokenizer_model(vocab.tokenizer_model)
-        else:
+        if isinstance(vocab, Vocab):
             of.add_meta_vocab(vocab)
             of.add_meta_special_vocab(svocab)
+        else:  # NoVocab
+            of.gguf.add_tokenizer_model(vocab.tokenizer_model)

         # tensor info
         for name, lazy_tensor in model.items():
@@ -1176,7 +1232,7 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT

     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}

-    raise Exception(f"Unexpected combination of types: {name_to_type}")
+    raise ValueError(f"Unexpected combination of types: {name_to_type}")


 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@@ -1186,7 +1242,7 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM

 def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel:
     tmap = gguf.TensorNameMap(ARCH, params.n_layer)
-    should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
+    should_skip = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))

     tmp = model

@@ -1213,8 +1269,7 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) ->
             if skip_unknown:
                 print(f"Unexpected tensor name: {name} - skipping")
                 continue
-            else:
-                raise Exception(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
+            raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")

         if tensor_type in should_skip:
             print(f"skipping tensor {name_new}")
@@ -1231,7 +1286,7 @@ def nth_multifile_path(path: Path, n: int) -> Path | None:
     the nth path in the model.
     '''
     # Support the following patterns:
-    patterns: list[tuple[str, str]] = [
+    patterns = [
         # - x.00.pth, x.01.pth, etc.
         (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'),
         # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
@@ -1277,9 +1332,9 @@ def load_some_model(path: Path) -> ModelPlus:
             globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
             files = [file for glob in globs for file in path.glob(glob)]
         if not files:
-            raise Exception(f"Can't find model in directory {path}")
+            raise FileNotFoundError(f"Can't find model in directory {path}")
         if len(files) > 1:
-            raise Exception(f"Found multiple models in {path}, not sure which to pick: {files}")
+            raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}")
         path = files[0]

     paths = find_multifile_paths(path)
@@ -1293,36 +1348,14 @@ def load_some_model(path: Path) -> ModelPlus:


 class VocabFactory:
-    _FILES = {"spm": "tokenizer.model", "bpe": "vocab.json", "hfft": "tokenizer.json"}
+    _VOCAB_CLASSES: list[type[Vocab]] = [SentencePieceVocab, BpeVocab, LlamaHfVocab]

     def __init__(self, path: Path):
         self.path = path
-        self.file_paths = self._detect_files()
-        print(f"Found vocab files: {self.file_paths}")
-
-    def _detect_files(self) -> dict[str, Path | None]:
-        def locate(file: str) -> Path | None:
-            if (path := self.path / file).exists():
-                return path
-            if (path := self.path.parent / file).exists():
-                return path
-            return None
-
-        return {vt: locate(f) for vt, f in self._FILES.items()}
-
-    def _select_file(self, vocab_types: list[str]) -> tuple[str, Path]:
-        for vtype in vocab_types:
-            try:
-                path = self.file_paths[vtype]
-            except KeyError:
-                raise ValueError(f"Unsupported vocabulary type {vtype}") from None
-            if path is not None:
-                return vtype, path
-        raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")

-    def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
+    def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab:
         load_merges = vocab.name == "bpe"
-        n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
+        n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None
         return gguf.SpecialVocab(
             model_parent_path,
             load_merges=load_merges,
@@ -1331,27 +1364,29 @@ class VocabFactory:
         )

     def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
-        vocab_type, path = self._select_file(vocab_types)
-        print(f"Loading vocab file {path!r}, type {vocab_type!r}")
+        vocab_classes: dict[str, type[Vocab]] = {cls.name: cls for cls in self._VOCAB_CLASSES}
+        selected_vocabs: dict[str, type[Vocab]] = {}
+        for vtype in vocab_types:
+            try:
+                selected_vocabs[vtype] = vocab_classes[vtype]
+            except KeyError:
+                raise ValueError(f"Unsupported vocabulary type {vtype}") from None

-        added_tokens_path = path.parent / "added_tokens.json"
-        if vocab_type == "bpe":
-            return BpeVocab(
-                path, added_tokens_path if added_tokens_path.exists() else None
-            )
-        if vocab_type == "spm":
-            return SentencePieceVocab(
-                path, added_tokens_path if added_tokens_path.exists() else None
-            )
-        if vocab_type == "hfft":
-            return HfVocab(
-                path.parent, added_tokens_path if added_tokens_path.exists() else None
-            )
-        raise ValueError(vocab_type)
+        for vtype, cls in selected_vocabs.items():
+            try:
+                vocab = cls(self.path)
+                break
+            except FileNotFoundError:
+                pass  # ignore unavailable tokenizers
+        else:
+            raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}")
+
+        print(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}")
+        return vocab

-    def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
-        vocab: Vocab
-        if len(vocab_types) == 1 and "no_vocab" in vocab_types:
+    def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]:
+        vocab: BaseVocab
+        if vocab_types is None:
             vocab = NoVocab()
         else:
             vocab = self._create_vocab_by_path(vocab_types)
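
Aside (not part of the patch): a standalone sketch of the for/else selection loop that _create_vocab_by_path now relies on; pick_first_available is a hypothetical stand-in for the factory.

def pick_first_available(candidates, base_path):
    for cls in candidates:
        try:
            vocab = cls(base_path)
            break  # found a usable tokenizer, stop probing
        except FileNotFoundError:
            pass  # this tokenizer type is not present, try the next one
    else:
        # the else of a for-loop runs only when no break happened,
        # i.e. every candidate raised FileNotFoundError
        raise FileNotFoundError(f"no tokenizer found in {base_path}")
    return vocab
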
@@ -1408,10 +1443,8 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")

     args = parser.parse_args(args_in)
-    if args.no_vocab:
-        if args.vocab_only:
-            raise ValueError("no need to specify --vocab-only if using --no-vocab")
-        args.vocab_type = "no_vocab"
+    if args.no_vocab and args.vocab_only:
+        raise ValueError("--vocab-only does not make sense with --no-vocab")

     if args.dump_single:
         model_plus = lazy_load_file(args.model)
@@ -1433,10 +1466,12 @@ def main(args_in: list[str] | None = None) -> None:
     params = Params.load(model_plus)
     if params.n_ctx == -1:
         if args.ctx is None:
-            raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
-                            "Please specify one with --ctx:\n"
-                            " - LLaMA v1: --ctx 2048\n"
-                            " - LLaMA v2: --ctx 4096\n")
+            msg = """\
+                The model doesn't have a context size, and you didn't specify one with --ctx
+                Please specify one with --ctx:
+                 - LLaMA v1: --ctx 2048
+                 - LLaMA v2: --ctx 4096"""
+            parser.error(textwrap.dedent(msg))
         params.n_ctx = args.ctx

     if args.outtype:
@@ -1451,9 +1486,11 @@ def main(args_in: list[str] | None = None) -> None:
     model_parent_path = model_plus.paths[0].parent
     vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
     vocab_factory = VocabFactory(vocab_path)
-    vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type.split(","), model_parent_path)
+    vocab_types = None if args.no_vocab else args.vocab_type.split(",")
+    vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path)

     if args.vocab_only:
+        assert isinstance(vocab, Vocab)
         if not args.outfile:
             raise ValueError("need --outfile if using --vocab-only")
         outfile = args.outfile