|
@@ -12,7 +12,7 @@ import sys
|
|
|
from enum import IntEnum
|
|
from enum import IntEnum
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from hashlib import sha256
|
|
from hashlib import sha256
|
|
|
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
|
|
|
|
|
|
|
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
|
|
|
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import torch
|
|
import torch
|
|
@@ -48,7 +48,6 @@ class Model:
|
|
|
|
|
|
|
|
dir_model: Path
|
|
dir_model: Path
|
|
|
ftype: int
|
|
ftype: int
|
|
|
- fname_out: Path
|
|
|
|
|
is_big_endian: bool
|
|
is_big_endian: bool
|
|
|
endianess: gguf.GGUFEndian
|
|
endianess: gguf.GGUFEndian
|
|
|
use_temp_file: bool
|
|
use_temp_file: bool
|
|
@@ -56,20 +55,20 @@ class Model:
|
|
|
part_names: list[str]
|
|
part_names: list[str]
|
|
|
is_safetensors: bool
|
|
is_safetensors: bool
|
|
|
hparams: dict[str, Any]
|
|
hparams: dict[str, Any]
|
|
|
- gguf_writer: gguf.GGUFWriter
|
|
|
|
|
block_count: int
|
|
block_count: int
|
|
|
tensor_map: gguf.TensorNameMap
|
|
tensor_map: gguf.TensorNameMap
|
|
|
tensor_names: set[str] | None
|
|
tensor_names: set[str] | None
|
|
|
|
|
+ fname_out: Path
|
|
|
|
|
+ gguf_writer: gguf.GGUFWriter
|
|
|
|
|
|
|
|
# subclasses should define this!
|
|
# subclasses should define this!
|
|
|
model_arch: gguf.MODEL_ARCH
|
|
model_arch: gguf.MODEL_ARCH
|
|
|
|
|
|
|
|
- def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
|
|
|
|
|
- if self.__class__ == Model:
|
|
|
|
|
- raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
|
|
|
|
|
|
|
+ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
|
|
|
|
|
+ if type(self) is Model:
|
|
|
|
|
+ raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
|
|
self.dir_model = dir_model
|
|
self.dir_model = dir_model
|
|
|
self.ftype = ftype
|
|
self.ftype = ftype
|
|
|
- self.fname_out = fname_out
|
|
|
|
|
self.is_big_endian = is_big_endian
|
|
self.is_big_endian = is_big_endian
|
|
|
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
|
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
|
|
self.use_temp_file = use_temp_file
|
|
self.use_temp_file = use_temp_file
|
|
@@ -79,10 +78,23 @@ class Model:
|
|
|
if not self.is_safetensors:
|
|
if not self.is_safetensors:
|
|
|
self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
|
|
self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
|
|
|
self.hparams = Model.load_hparams(self.dir_model)
|
|
self.hparams = Model.load_hparams(self.dir_model)
|
|
|
- self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
|
|
|
|
|
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
|
|
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
|
|
|
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
|
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
|
|
self.tensor_names = None
|
|
self.tensor_names = None
|
|
|
|
|
+ if self.ftype == gguf.LlamaFileType.GUESSED:
|
|
|
|
|
+ # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
|
|
|
|
|
+ _, first_tensor = next(self.get_tensors())
|
|
|
|
|
+ if first_tensor.dtype == torch.float16:
|
|
|
|
|
+ logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
|
|
|
|
|
+ self.ftype = gguf.LlamaFileType.MOSTLY_F16
|
|
|
|
|
+ else:
|
|
|
|
|
+ logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
|
|
|
|
|
+ self.ftype = gguf.LlamaFileType.MOSTLY_BF16
|
|
|
|
|
+ ftype_up: str = self.ftype.name.partition("_")[2].upper()
|
|
|
|
|
+ ftype_lw: str = ftype_up.lower()
|
|
|
|
|
+ # allow templating the file name with the output ftype, useful with the "auto" ftype
|
|
|
|
|
+ self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
|
|
|
|
|
+ self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
|
|
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
|
def __init_subclass__(cls):
|
|
def __init_subclass__(cls):
|
|
@@ -142,14 +154,27 @@ class Model:
|
|
|
raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
|
|
raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
|
|
|
|
|
|
|
|
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
|
|
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
|
|
|
- name: str = gguf.TENSOR_NAMES[key]
|
|
|
|
|
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
|
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
|
|
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
|
|
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
|
|
|
|
|
+ name: str = gguf.TENSOR_NAMES[key]
|
|
|
if "{bid}" in name:
|
|
if "{bid}" in name:
|
|
|
assert bid is not None
|
|
assert bid is not None
|
|
|
name = name.format(bid=bid)
|
|
name = name.format(bid=bid)
|
|
|
return name + suffix
|
|
return name + suffix
|
|
|
|
|
|
|
|
|
|
+ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool:
|
|
|
|
|
+ if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
|
|
|
|
+ return False
|
|
|
|
|
+ key_name: str = gguf.TENSOR_NAMES[key]
|
|
|
|
|
+ if "{bid}" in key_name:
|
|
|
|
|
+ if bid is None:
|
|
|
|
|
+ return False
|
|
|
|
|
+ key_name = key_name.format(bid=bid)
|
|
|
|
|
+ else:
|
|
|
|
|
+ if bid is not None:
|
|
|
|
|
+ return False
|
|
|
|
|
+ return name == (key_name + suffix)
|
|
|
|
|
+
|
|
|
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
|
|
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
|
|
|
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
|
|
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
|
|
|
if new_name is None:
|
|
if new_name is None:
|
|
@@ -215,6 +240,23 @@ class Model:
|
|
|
return False
|
|
return False
|
|
|
|
|
|
|
|
def write_tensors(self):
|
|
def write_tensors(self):
|
|
|
|
|
+ # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
|
|
|
|
|
+ def np_fp32_to_bf16(n: np.ndarray):
|
|
|
|
|
+ # force nan to quiet
|
|
|
|
|
+ n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
|
|
|
|
|
+ # flush subnormals to zero
|
|
|
|
|
+ n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
|
|
|
|
|
+ # round to nearest even
|
|
|
|
|
+ n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
|
|
|
|
|
+ return n.astype(np.int16)
|
|
|
|
|
+
|
|
|
|
|
+ # Doing this row-wise is much, much faster than element-wise, hence the signature
|
|
|
|
|
+ v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
|
|
|
|
|
+ if self.lazy:
|
|
|
|
|
+ # TODO: find a way to implicitly wrap np.vectorize functions
|
|
|
|
|
+ # NOTE: the type is changed to reflect otypes passed to np.vectorize above
|
|
|
|
|
+ v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
|
|
|
|
|
+
|
|
|
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
|
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
|
|
|
|
|
|
|
for name, data_torch in self.get_tensors():
|
|
for name, data_torch in self.get_tensors():
|
|
@@ -239,35 +281,60 @@ class Model:
|
|
|
data: np.ndarray = data # type hint
|
|
data: np.ndarray = data # type hint
|
|
|
n_dims = len(data.shape)
|
|
n_dims = len(data.shape)
|
|
|
data_dtype = data.dtype
|
|
data_dtype = data.dtype
|
|
|
-
|
|
|
|
|
- # if f32 desired, convert any float16 to float32
|
|
|
|
|
- if self.ftype == 0 and data_dtype == np.float16:
|
|
|
|
|
- data = data.astype(np.float32)
|
|
|
|
|
|
|
+ data_qtype: gguf.GGMLQuantizationType | None = None
|
|
|
|
|
|
|
|
# when both are True, f32 should win
|
|
# when both are True, f32 should win
|
|
|
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
|
|
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
|
|
|
extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
|
|
extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
|
|
|
|
|
|
|
|
# Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
|
|
# Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
|
|
|
- extra_f32 = extra_f32 or n_dims == 1 or new_name.endswith("_norm.weight")
|
|
|
|
|
|
|
+ # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
|
|
|
|
|
+ extra_f32 = any(cond for cond in (
|
|
|
|
|
+ extra_f32,
|
|
|
|
|
+ n_dims == 1,
|
|
|
|
|
+ new_name.endswith("_norm.weight"),
|
|
|
|
|
+ ))
|
|
|
|
|
+
|
|
|
|
|
+ # Some tensor types are always in float32
|
|
|
|
|
+ extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
|
|
|
|
|
+ gguf.MODEL_TENSOR.FFN_GATE_INP,
|
|
|
|
|
+ gguf.MODEL_TENSOR.POS_EMBD,
|
|
|
|
|
+ gguf.MODEL_TENSOR.TOKEN_TYPES,
|
|
|
|
|
+ ))
|
|
|
|
|
|
|
|
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
|
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
|
|
- extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)
|
|
|
|
|
-
|
|
|
|
|
- # when both extra_f32 and extra_f16 are False, convert to float32 by default
|
|
|
|
|
- if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
|
|
|
|
|
- data = data.astype(np.float32)
|
|
|
|
|
-
|
|
|
|
|
- if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
|
|
|
|
|
- data = data.astype(np.float16)
|
|
|
|
|
|
|
+ extra_f16 = any(cond for cond in (
|
|
|
|
|
+ extra_f16,
|
|
|
|
|
+ (name.endswith(".weight") and n_dims >= 2),
|
|
|
|
|
+ ))
|
|
|
|
|
+
|
|
|
|
|
+ if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
|
|
|
|
|
+ if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
|
|
|
|
|
+ if data_dtype != np.float16:
|
|
|
|
|
+ data = data.astype(np.float16)
|
|
|
|
|
+ data_qtype = gguf.GGMLQuantizationType.F16
|
|
|
|
|
+
|
|
|
|
|
+ elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
|
|
|
|
+ if data_dtype != np.float32:
|
|
|
|
|
+ data = data.astype(np.float32)
|
|
|
|
|
+ data = v_fp32_to_bf16(data.view(np.int32))
|
|
|
|
|
+ assert data.dtype == np.int16
|
|
|
|
|
+ data_qtype = gguf.GGMLQuantizationType.BF16
|
|
|
|
|
+
|
|
|
|
|
+ else: # by default, convert to float32
|
|
|
|
|
+ if data_dtype != np.float32:
|
|
|
|
|
+ data = data.astype(np.float32)
|
|
|
|
|
+ data_qtype = gguf.GGMLQuantizationType.F32
|
|
|
|
|
+
|
|
|
|
|
+ assert data_qtype is not None
|
|
|
|
|
|
|
|
# reverse shape to make it similar to the internal ggml dimension order
|
|
# reverse shape to make it similar to the internal ggml dimension order
|
|
|
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
|
|
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
|
|
|
|
|
|
|
|
# n_dims is implicit in the shape
|
|
# n_dims is implicit in the shape
|
|
|
- logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
|
|
|
|
|
|
|
+ logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
|
|
|
|
|
|
|
|
- self.gguf_writer.add_tensor(new_name, data)
|
|
|
|
|
|
|
+ self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
|
|
|
|
|
|
|
|
def write(self):
|
|
def write(self):
|
|
|
self.write_tensors()
|
|
self.write_tensors()
|
|
@@ -2044,12 +2111,6 @@ class BertModel(Model):
|
|
|
|
|
|
|
|
return [(self.map_tensor_name(name), data_torch)]
|
|
return [(self.map_tensor_name(name), data_torch)]
|
|
|
|
|
|
|
|
- def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
|
|
|
|
- del new_name, bid, n_dims # unused
|
|
|
|
|
-
|
|
|
|
|
- # not used with get_rows, must be F32
|
|
|
|
|
- return name == "embeddings.token_type_embeddings.weight"
|
|
|
|
|
-
|
|
|
|
|
|
|
|
|
|
@Model.register("NomicBertModel")
|
|
@Model.register("NomicBertModel")
|
|
|
class NomicBertModel(BertModel):
|
|
class NomicBertModel(BertModel):
|
|
@@ -2339,92 +2400,40 @@ class JinaBertV2Model(BertModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
# tree of lazy tensors
|
|
# tree of lazy tensors
|
|
|
-class LazyTorchTensor:
|
|
|
|
|
- _meta: Tensor
|
|
|
|
|
- _data: Tensor | None
|
|
|
|
|
- _args: tuple
|
|
|
|
|
- _func: Callable[[tuple], Tensor] | None
|
|
|
|
|
-
|
|
|
|
|
- def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
|
|
|
|
|
- self._meta = meta
|
|
|
|
|
- self._data = data
|
|
|
|
|
- self._args = args
|
|
|
|
|
- self._func = func
|
|
|
|
|
-
|
|
|
|
|
- @staticmethod
|
|
|
|
|
- def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
|
|
|
|
|
- # TODO: dict and set
|
|
|
|
|
- if isinstance(o, (list, tuple)):
|
|
|
|
|
- L = []
|
|
|
|
|
- for item in o:
|
|
|
|
|
- L.append(LazyTorchTensor._recurse_apply(item, fn))
|
|
|
|
|
- if isinstance(o, tuple):
|
|
|
|
|
- L = tuple(L)
|
|
|
|
|
- return L
|
|
|
|
|
- elif isinstance(o, LazyTorchTensor):
|
|
|
|
|
- return fn(o)
|
|
|
|
|
- else:
|
|
|
|
|
- return o
|
|
|
|
|
-
|
|
|
|
|
- def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
|
|
|
|
|
- def wrapped_fn(*args, **kwargs):
|
|
|
|
|
- if kwargs is None:
|
|
|
|
|
- kwargs = {}
|
|
|
|
|
- args = ((self,) if use_self else ()) + args
|
|
|
|
|
-
|
|
|
|
|
- meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
|
|
|
|
|
-
|
|
|
|
|
- return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
|
|
|
|
|
- return wrapped_fn
|
|
|
|
|
-
|
|
|
|
|
- def __getattr__(self, __name: str) -> Any:
|
|
|
|
|
- meta_attr = getattr(self._meta, __name)
|
|
|
|
|
- if callable(meta_attr):
|
|
|
|
|
- return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
|
|
|
|
|
- elif isinstance(meta_attr, torch.Tensor):
|
|
|
|
|
- # for things like self.T
|
|
|
|
|
- return self._wrap_fn(lambda s: getattr(s, __name))(self)
|
|
|
|
|
- else:
|
|
|
|
|
- return meta_attr
|
|
|
|
|
|
|
+class LazyTorchTensor(gguf.LazyBase):
|
|
|
|
|
+ _tensor_type = torch.Tensor
|
|
|
|
|
+ # to keep the type-checker happy
|
|
|
|
|
+ dtype: torch.dtype
|
|
|
|
|
+ shape: torch.Size
|
|
|
|
|
|
|
|
|
|
+ # only used when converting a torch.Tensor to a np.ndarray
|
|
|
_dtype_map: dict[torch.dtype, type] = {
|
|
_dtype_map: dict[torch.dtype, type] = {
|
|
|
torch.float16: np.float16,
|
|
torch.float16: np.float16,
|
|
|
torch.float32: np.float32,
|
|
torch.float32: np.float32,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- def numpy(self) -> gguf.LazyTensor:
|
|
|
|
|
|
|
+ def numpy(self) -> gguf.LazyNumpyTensor:
|
|
|
dtype = self._dtype_map[self.dtype]
|
|
dtype = self._dtype_map[self.dtype]
|
|
|
- return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
|
|
|
|
|
-
|
|
|
|
|
- @overload
|
|
|
|
|
- @staticmethod
|
|
|
|
|
- def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
|
|
|
|
|
-
|
|
|
|
|
- @overload
|
|
|
|
|
- @staticmethod
|
|
|
|
|
- def to_eager(t: tuple) -> tuple: ...
|
|
|
|
|
-
|
|
|
|
|
- @staticmethod
|
|
|
|
|
- def to_eager(t: Any) -> Any:
|
|
|
|
|
- def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
|
|
|
|
|
- # wake up the lazy tensor
|
|
|
|
|
- if _t._data is None and _t._func is not None:
|
|
|
|
|
- # recurse into its arguments
|
|
|
|
|
- _t._args = LazyTorchTensor.to_eager(_t._args)
|
|
|
|
|
- _t._data = _t._func(_t._args)
|
|
|
|
|
- if _t._data is not None:
|
|
|
|
|
- return _t._data
|
|
|
|
|
- else:
|
|
|
|
|
- raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
|
|
|
|
|
-
|
|
|
|
|
- # recurse into lists and/or tuples, keeping their structure
|
|
|
|
|
- return LazyTorchTensor._recurse_apply(t, simple_to_eager)
|
|
|
|
|
|
|
+ return gguf.LazyNumpyTensor(
|
|
|
|
|
+ meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
|
|
|
|
|
+ lazy=self._lazy,
|
|
|
|
|
+ args=(self,),
|
|
|
|
|
+ func=(lambda s: s[0].numpy())
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- @staticmethod
|
|
|
|
|
- def from_eager(t: Tensor) -> Tensor:
|
|
|
|
|
- if (t.__class__ == LazyTorchTensor):
|
|
|
|
|
|
|
+ @classmethod
|
|
|
|
|
+ def eager_to_meta(cls, t: Tensor) -> Tensor:
|
|
|
|
|
+ if t.is_meta:
|
|
|
return t
|
|
return t
|
|
|
- return LazyTorchTensor(meta=t.detach().to("meta"), data=t) # type: ignore
|
|
|
|
|
|
|
+ return t.detach().to("meta")
|
|
|
|
|
+
|
|
|
|
|
+ @classmethod
|
|
|
|
|
+ def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
|
|
|
|
|
+ m = m.detach()
|
|
|
|
|
+ if not m.is_meta:
|
|
|
|
|
+ m = m.to("meta")
|
|
|
|
|
+ m.dtype = dtype
|
|
|
|
|
+ return m
|
|
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
@@ -2435,28 +2444,8 @@ class LazyTorchTensor:
|
|
|
|
|
|
|
|
if func is torch.Tensor.numpy:
|
|
if func is torch.Tensor.numpy:
|
|
|
return args[0].numpy()
|
|
return args[0].numpy()
|
|
|
- if func is torch.equal:
|
|
|
|
|
- eager_args = LazyTorchTensor.to_eager(args)
|
|
|
|
|
- return func(*eager_args, **kwargs)
|
|
|
|
|
|
|
|
|
|
- return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
|
|
|
|
|
-
|
|
|
|
|
- # special methods bypass __getattr__, so they need to be added manually
|
|
|
|
|
- # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
|
|
|
|
|
- # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
|
|
|
|
|
- # as self._meta is currently used), because then the following
|
|
|
|
|
- # operations would by default not be wrapped, and so not propagated
|
|
|
|
|
- # when the tensor is made eager.
|
|
|
|
|
- # It's better to get non-silent errors for not-yet-supported operators.
|
|
|
|
|
- # TODO: add more when needed to avoid clutter, or find a more concise way
|
|
|
|
|
- def __neg__(self, *args): # mamba
|
|
|
|
|
- return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
|
|
|
|
|
-
|
|
|
|
|
- def __add__(self, *args): # gemma
|
|
|
|
|
- return self._wrap_fn(torch.Tensor.__add__)(self, *args)
|
|
|
|
|
-
|
|
|
|
|
- def __getitem__(self, *args): # bloom falcon refact internlm2
|
|
|
|
|
- return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
|
|
|
|
|
|
|
+ return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
def parse_args() -> argparse.Namespace:
|
|
@@ -2472,11 +2461,11 @@ def parse_args() -> argparse.Namespace:
|
|
|
)
|
|
)
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
"--outfile", type=Path,
|
|
"--outfile", type=Path,
|
|
|
- help="path to write to; default: based on input",
|
|
|
|
|
|
|
+ help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
|
|
)
|
|
)
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
- "--outtype", type=str, choices=["f32", "f16"], default="f16",
|
|
|
|
|
- help="output format - use f32 for float32, f16 for float16",
|
|
|
|
|
|
|
+ "--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
|
|
|
|
|
+ help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
|
|
)
|
|
)
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
"--bigendian", action="store_true",
|
|
"--bigendian", action="store_true",
|
|
@@ -2530,16 +2519,18 @@ def main() -> None:
|
|
|
logger.error(f'Error: {args.model} is not a directory')
|
|
logger.error(f'Error: {args.model} is not a directory')
|
|
|
sys.exit(1)
|
|
sys.exit(1)
|
|
|
|
|
|
|
|
- ftype_map = {
|
|
|
|
|
- "f32": gguf.GGMLQuantizationType.F32,
|
|
|
|
|
- "f16": gguf.GGMLQuantizationType.F16,
|
|
|
|
|
|
|
+ ftype_map: dict[str, gguf.LlamaFileType] = {
|
|
|
|
|
+ "f32": gguf.LlamaFileType.ALL_F32,
|
|
|
|
|
+ "f16": gguf.LlamaFileType.MOSTLY_F16,
|
|
|
|
|
+ "bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
|
|
|
|
+ "auto": gguf.LlamaFileType.GUESSED,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if args.outfile is not None:
|
|
if args.outfile is not None:
|
|
|
fname_out = args.outfile
|
|
fname_out = args.outfile
|
|
|
else:
|
|
else:
|
|
|
# output in the same directory as the model by default
|
|
# output in the same directory as the model by default
|
|
|
- fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
|
|
|
|
|
|
|
+ fname_out = dir_model / 'ggml-model-{ftype}.gguf'
|
|
|
|
|
|
|
|
logger.info(f"Loading model: {dir_model.name}")
|
|
logger.info(f"Loading model: {dir_model.name}")
|
|
|
|
|
|
|
@@ -2555,14 +2546,16 @@ def main() -> None:
|
|
|
logger.info("Set model tokenizer")
|
|
logger.info("Set model tokenizer")
|
|
|
model_instance.set_vocab()
|
|
model_instance.set_vocab()
|
|
|
|
|
|
|
|
|
|
+ model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
|
|
|
|
|
+
|
|
|
if args.vocab_only:
|
|
if args.vocab_only:
|
|
|
- logger.info(f"Exporting model vocab to '{fname_out}'")
|
|
|
|
|
|
|
+ logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
|
|
|
model_instance.write_vocab()
|
|
model_instance.write_vocab()
|
|
|
else:
|
|
else:
|
|
|
- logger.info(f"Exporting model to '{fname_out}'")
|
|
|
|
|
|
|
+ logger.info(f"Exporting model to '{model_instance.fname_out}'")
|
|
|
model_instance.write()
|
|
model_instance.write()
|
|
|
|
|
|
|
|
- logger.info(f"Model successfully exported to '{fname_out}'")
|
|
|
|
|
|
|
+ logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|