Sfoglia il codice sorgente

convert-hf : save memory with lazy evaluation (#7075)

* convert-hf : begin refactoring write_tensor

* convert : upgrade to sentencepiece v0.2.0

* convert-hf : remove unused n_dims in extra_*_tensors

* convert-hf : simplify MoE weights stacking

* convert-hf : flake8 linter doesn't like semicolons

* convert-hf : allow unusual model part names

For example, loading `model-00001-of-00001.safetensors` now works.

* convert-hf : fix stacking MoE expert tensors

`torch.stack` and `torch.cat` don't do the same thing.

* convert-hf : fix Mamba conversion

Tested to work even with a SentencePiece-based tokenizer.

* convert : use a string for the SentencePiece tokenizer path

* convert-hf : display tensor shape

* convert-hf : convert norms to f32 by default

* convert-hf : sort model part names

`os.listdir` is said to list files in arbitrary order.
Sorting the file names should let "model-00009-of-00042.safetensors"
be loaded before "model-00010-of-00042.safetensors".

* convert-hf : use an ABC for Model again

It seems Protocol can't be used as a statically type-checked ABC,
because its subclasses also can't be instantiated. (why did it seem to work?)

At least there's still a way to throw an error when forgetting to define
the `model_arch` property of any registered Model subclasses.

* convert-hf : use a plain class for Model, and forbid direct instantiation

There are no abstract methods used anyway,
so using ABC isn't really necessary.

* convert-hf : more consistent formatting of cmdline args

* convert-hf : align the message logged for converted tensors

* convert-hf : fix Refact conversion

* convert-hf : save memory with lazy evaluation

* convert-hf : flake8 doesn't like lowercase L as a variable name

* convert-hf : remove einops requirement for InternLM2

* convert-hf : faster model parts loading

Instead of pre-loading them all into a dict, iterate on the tensors
in the model parts progressively as needed in Model.write_tensors

Conversion for some architectures relies on checking for the presence
of specific tensor names, so for multi-part models, the weight map is read
from the relevant json file to quickly get these names up-front.

* convert-hf : minor changes for consistency

* gguf-py : add tqdm as a dependency

It's small, and used for a progress bar
in GGUFWriter.write_tensors_to_file
compilade 1 anno fa
parent
commit
f98eb31c51

File diff suppressed because it is too large
+ 309 - 500
convert-hf-to-gguf.py


+ 12 - 8
convert.py

@@ -284,6 +284,7 @@ class Params:
         n_experts      = None
         n_experts_used = None
         f_rope_freq_base = None
+        n_ff = None
 
         # hack to determine LLaMA v1 vs v2 vs CodeLlama
         if config.get("moe"):
@@ -308,6 +309,8 @@ class Params:
             n_experts_used = config["moe"]["num_experts_per_tok"]
             f_rope_freq_base = 1e6
 
+        assert n_ff is not None
+
         return Params(
             n_vocab          = model["tok_embeddings.weight"].shape[0],
             n_embd           = config["dim"],
@@ -462,7 +465,8 @@ class SentencePieceVocab(Vocab):
             # not found in alternate location either
             raise FileNotFoundError('Cannot find tokenizer.model')
 
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+        self.sentencepiece_tokenizer = SentencePieceProcessor()
+        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
         vocab_size = self.sentencepiece_tokenizer.vocab_size()
 
         new_tokens       = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
@@ -482,23 +486,23 @@ class SentencePieceVocab(Vocab):
     def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         tokenizer = self.sentencepiece_tokenizer
         for i in range(tokenizer.vocab_size()):
-            piece = tokenizer.id_to_piece(i)
+            piece = tokenizer.IdToPiece(i)
             text         = piece.encode("utf-8")
-            score: float = tokenizer.get_score(i)
+            score: float = tokenizer.GetScore(i)
 
             toktype = gguf.TokenType.NORMAL
-            if tokenizer.is_unknown(i):
+            if tokenizer.IsUnknown(i):
                 toktype = gguf.TokenType.UNKNOWN
-            if tokenizer.is_control(i):
+            if tokenizer.IsControl(i):
                 toktype = gguf.TokenType.CONTROL
 
             # NOTE: I think added_tokens are user defined.
             # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
             # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
 
-            if tokenizer.is_unused(i):
+            if tokenizer.IsUnused(i):
                 toktype = gguf.TokenType.UNUSED
-            if tokenizer.is_byte(i):
+            if tokenizer.IsByte(i):
                 toktype = gguf.TokenType.BYTE
 
             yield text, score, toktype
@@ -906,7 +910,7 @@ class LazyUnpickler(pickle.Unpickler):
     def rebuild_from_type_v2(func, new_type, args, state):
         return func(*args)
 
-    CLASSES = {
+    CLASSES: dict[tuple[str, str], type[LazyTensor] | LazyStorageKind] = {
         # getattr used here as a workaround for mypy not being smart enough to determine
         # the staticmethods have a __func__ attribute.
         ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),

+ 1 - 1
examples/server/tests/features/steps/steps.py

@@ -939,7 +939,7 @@ async def oai_chat_completions(user_prompt,
                     while event_received:
                         event_received = False
                         async for line_in_bytes in response.content:
-                            line = line_in_bytes.decode('utf8')
+                            line = line_in_bytes.decode('utf-8')
                             line = line.rstrip('\n').rstrip('\r')
                             if line == '':
                                 continue

+ 1 - 1
gguf-py/gguf/constants.py

@@ -860,7 +860,7 @@ class GGUFValueType(IntEnum):
 # Note: Does not support GGML_QKK_64
 QK_K = 256
 # Items here are (block size, type size)
-GGML_QUANT_SIZES = {
+GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
     GGMLQuantizationType.F32:     (1, 4),
     GGMLQuantizationType.F16:     (1, 2),
     GGMLQuantizationType.Q4_0:    (32, 2 + 16),

+ 4 - 4
gguf-py/gguf/gguf_reader.py

@@ -65,7 +65,7 @@ class ReaderTensor(NamedTuple):
 
 class GGUFReader:
     # I - same as host, S - swapped
-    byte_order: Literal['I' | 'S'] = 'I'
+    byte_order: Literal['I'] | Literal['S'] = 'I'
     alignment: int = GGUF_DEFAULT_ALIGNMENT
 
     # Note: Internal helper, API may change.
@@ -83,7 +83,7 @@ class GGUFReader:
         GGUFValueType.BOOL:    np.bool_,
     }
 
-    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'):
+    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
         self.data = np.memmap(path, mode = mode)
         offs = 0
         if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
@@ -128,7 +128,7 @@ class GGUFReader:
         return self.tensors[idx]
 
     def _get(
-        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None,
+        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
     ) -> npt.NDArray[Any]:
         count = int(count)
         itemsize = int(np.empty([], dtype = dtype).itemsize)
@@ -250,7 +250,7 @@ class GGUFReader:
                 raise ValueError(f'Found duplicated tensor with name {tensor_name}')
             tensor_names.add(tensor_name)
             ggml_type = GGMLQuantizationType(raw_dtype[0])
-            n_elems = np.prod(dims)
+            n_elems = int(np.prod(dims))
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
             n_bytes = n_elems * type_size // block_size
             data_offs = int(start_offs + offset_tensor[0])

+ 68 - 9
gguf-py/gguf/gguf_writer.py

@@ -7,7 +7,7 @@ import struct
 import tempfile
 from enum import Enum, auto
 from io import BufferedWriter
-from typing import IO, Any, Sequence, Mapping
+from typing import IO, Any, Callable, Sequence, Mapping
 from string import ascii_letters, digits
 
 import numpy as np
@@ -28,6 +28,47 @@ from .constants import (
 logger = logging.getLogger(__name__)
 
 
+class LazyTensor:
+    data: Callable[[], np.ndarray[Any, Any]]
+    # to avoid too deep recursion
+    functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
+    dtype: np.dtype[Any]
+    shape: tuple[int, ...]
+
+    def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
+        self.data = data
+        self.functions = []
+        self.dtype = np.dtype(dtype)
+        self.shape = shape
+
+    def astype(self, dtype: type, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.astype(dtype, **kwargs))
+        self.dtype = np.dtype(dtype)
+        return self
+
+    @property
+    def nbytes(self) -> int:
+        size = 1
+        for n in self.shape:
+            size *= n
+        return size * self.dtype.itemsize
+
+    def tofile(self, *args, **kwargs) -> None:
+        data = self.data()
+        for f in self.functions:
+            data = f(data)
+        assert data.shape == self.shape
+        assert data.dtype == self.dtype
+        assert data.nbytes == self.nbytes
+        self.functions = []
+        self.data = lambda: data
+        data.tofile(*args, **kwargs)
+
+    def byteswap(self, *args, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.byteswap(*args, **kwargs))
+        return self
+
+
 class WriterState(Enum):
     EMPTY   = auto()
     HEADER  = auto()
@@ -38,7 +79,7 @@ class WriterState(Enum):
 class GGUFWriter:
     fout: BufferedWriter
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: list[np.ndarray[Any, Any]]
+    tensors: list[np.ndarray[Any, Any] | LazyTensor]
     _simple_value_packing = {
         GGUFValueType.UINT8:   "B",
         GGUFValueType.INT8:    "b",
@@ -176,7 +217,7 @@ class GGUFWriter:
         if pack_fmt is not None:
             self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
         elif vtype == GGUFValueType.STRING:
-            encoded_val = val.encode("utf8") if isinstance(val, str) else val
+            encoded_val = val.encode("utf-8") if isinstance(val, str) else val
             self.kv_data += self._pack("Q", len(encoded_val))
             self.kv_data += encoded_val
         elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
@@ -205,7 +246,7 @@ class GGUFWriter:
             raise ValueError(f'Duplicated tensor name {name}')
         self.ti_names.add(name)
 
-        encoded_name = name.encode("utf8")
+        encoded_name = name.encode("utf-8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name
         n_dims = len(tensor_shape)
@@ -237,7 +278,7 @@ class GGUFWriter:
         self.ti_data_count += 1
 
     def add_tensor(
-        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+        self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
         if self.endianess == GGUFEndian.BIG:
@@ -262,7 +303,7 @@ class GGUFWriter:
         if pad != 0:
             fp.write(bytes([0] * pad))
 
-    def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
+    def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
         if self.state is not WriterState.TI_DATA:
             raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
 
@@ -272,15 +313,33 @@ class GGUFWriter:
         tensor.tofile(self.fout)
         self.write_padding(self.fout, tensor.nbytes)
 
-    def write_tensors_to_file(self) -> None:
+    def write_tensors_to_file(self, *, progress: bool = False) -> None:
         self.write_ti_data_to_file()
 
         self.write_padding(self.fout, self.fout.tell())
 
         if self.temp_file is None:
+            self.tensors.reverse()  # to pop from the "beginning" in constant time
+
+            if progress:
+                from tqdm import tqdm
+
+                total_bytes = sum(t.nbytes for t in self.tensors)
+
+                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+
+                while True:
+                    try:
+                        tensor = self.tensors.pop()
+                    except IndexError:
+                        break
+                    tensor.tofile(self.fout)
+                    bar.update(tensor.nbytes)
+                    self.write_padding(self.fout, tensor.nbytes)
+                return
             while True:
                 try:
-                    tensor = self.tensors.pop(0)
+                    tensor = self.tensors.pop()
                 except IndexError:
                     break
                 tensor.tofile(self.fout)
@@ -479,7 +538,7 @@ class GGUFWriter:
         self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
 
     def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
-        if isinstance(value, list):
+        if not isinstance(value, str):
             template_default = None
             template_names = set()
 

+ 3 - 3
gguf-py/gguf/vocab.py

@@ -4,7 +4,7 @@ import logging
 import json
 import os
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any, Callable, Sequence, Mapping, Iterable
 
 from .gguf_writer import GGUFWriter
 
@@ -15,11 +15,11 @@ class SpecialVocab:
     merges: list[str]
     add_special_token: dict[str, bool]
     special_token_ids: dict[str, int]
-    chat_template: str | None
+    chat_template: str | Sequence[Mapping[str, str]] | None
 
     def __init__(
         self, path: str | os.PathLike[str], load_merges: bool = False,
-        special_token_types: tuple[str, ...] | None = None,
+        special_token_types: Iterable[str] | None = None,
         n_vocab: int | None = None,
     ):
         self.special_token_ids = {}

+ 1 - 0
gguf-py/pyproject.toml

@@ -21,6 +21,7 @@ classifiers = [
 [tool.poetry.dependencies]
 python = ">=3.8"
 numpy = ">=1.17"
+tqdm = ">=4.27"
 
 [tool.poetry.dev-dependencies]
 pytest = "^5.2"

+ 1 - 1
gguf-py/scripts/gguf-dump.py

@@ -47,7 +47,7 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
         if len(field.types) == 1:
             curr_type = field.types[0]
             if curr_type == GGUFValueType.STRING:
-                log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60]))
+                log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60]))
             elif field.types[0] in reader.gguf_scalar_to_np:
                 log_message += ' = {0}'.format(field.parts[-1][0])
         print(log_message)  # noqa: NP100

+ 6 - 6
gguf-py/scripts/gguf-new-metadata.py

@@ -7,7 +7,7 @@ import json
 from pathlib import Path
 
 import numpy as np
-from typing import Any, Mapping, Sequence
+from typing import Any, Sequence
 
 # Necessary to load the local gguf package
 if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
@@ -34,7 +34,7 @@ def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
         return host_endian
 
 
-def decode_field(field: gguf.ReaderField) -> Any:
+def decode_field(field: gguf.ReaderField | None) -> Any:
     if field and field.types:
         main_type = field.types[0]
 
@@ -42,11 +42,11 @@ def decode_field(field: gguf.ReaderField) -> Any:
             sub_type = field.types[-1]
 
             if sub_type == gguf.GGUFValueType.STRING:
-                return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]
+                return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data]
             else:
                 return [pv for idx in field.data for pv in field.parts[idx].tolist()]
         if main_type == gguf.GGUFValueType.STRING:
-            return str(bytes(field.parts[-1]), encoding='utf8')
+            return str(bytes(field.parts[-1]), encoding='utf-8')
         else:
             return field.parts[-1][0]
 
@@ -59,7 +59,7 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
     return decode_field(field)
 
 
-def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None:
+def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, str], remove_metadata: Sequence[str]) -> None:
     for field in reader.fields.values():
         # Suppress virtual fields and fields written by GGUFWriter
         if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
@@ -101,7 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
 
     for tensor in reader.tensors:
         # Dimensions are written in reverse order, so flip them first
-        shape = np.flipud(tensor.shape)
+        shape = np.flipud(tensor.shape).tolist()
         writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
 
     writer.write_header_to_file()

+ 3 - 0
pyrightconfig.json

@@ -0,0 +1,3 @@
+{
+  "extraPaths": ["gguf-py"],
+}

+ 0 - 1
requirements/requirements-convert-hf-to-gguf-update.txt

@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0

+ 0 - 1
requirements/requirements-convert-hf-to-gguf.txt

@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0

+ 1 - 1
requirements/requirements-convert.txt

@@ -1,5 +1,5 @@
 numpy~=1.24.4
-sentencepiece~=0.1.98
+sentencepiece~=0.2.0
 transformers>=4.40.1,<5.0.0
 gguf>=0.1.0
 protobuf>=4.21.0,<5.0.0

Some files were not shown because too many files changed in this diff