Преглед изворни кода

convert.py : advanced option (#2753)

* Allow convert.py to convert to q8_0

Fix issue with bounded_parallel_map and greedy consuming iterator

Display elapsed time during conversion

* Add --concurrency option

Minor improvements to help text

Clean up bounded_parallel_map function a bit

* Massive speed improvement thanks to Cebtenzzre

* Refactor types
Kerfuffle пре 2 година
родитељ
комит
730d9c681e
1 измењених фајлова са 133 додато и 73 уклоњено
  1. 133 73
      convert.py

+ 133 - 73
convert.py

@@ -3,6 +3,7 @@
 import gguf
 import gguf
 import argparse
 import argparse
 import concurrent.futures
 import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
 import copy
 import copy
 import enum
 import enum
 import faulthandler
 import faulthandler
@@ -17,13 +18,14 @@ import re
 import signal
 import signal
 import struct
 import struct
 import sys
 import sys
+import time
 import zipfile
 import zipfile
 import numpy as np
 import numpy as np
 
 
 from abc import ABCMeta, abstractmethod
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from dataclasses import dataclass
 from pathlib import Path
 from pathlib import Path
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
 from sentencepiece import SentencePieceProcessor  # type: ignore
 from sentencepiece import SentencePieceProcessor  # type: ignore
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -37,30 +39,70 @@ NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
 ARCH=gguf.MODEL_ARCH.LLAMA
 ARCH=gguf.MODEL_ARCH.LLAMA
 NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
 NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
 
 
+DEFAULT_CONCURRENCY = 8
 #
 #
 # data types
 # data types
 #
 #
 
 
 @dataclass(frozen=True)
 @dataclass(frozen=True)
-class UnquantizedDataType:
+class DataType:
     name: str
     name: str
+    dtype: 'np.dtype[Any]'
+    valid_conversions: List[str]
 
 
-DT_F16  = UnquantizedDataType('F16')
-DT_F32  = UnquantizedDataType('F32')
-DT_I32  = UnquantizedDataType('I32')
-DT_BF16 = UnquantizedDataType('BF16')
+    def elements_to_bytes(self, n_elements: int) -> int:
+        return n_elements * self.dtype.itemsize
 
 
-DataType = Union[UnquantizedDataType]
+@dataclass(frozen=True)
+class UnquantizedDataType(DataType):
+    pass
 
 
-DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
-    DT_BF16: np.dtype(np.uint16),
-    DT_F16:  np.dtype(np.float16),
-    DT_F32:  np.dtype(np.float32),
-    DT_I32:  np.dtype(np.int32),
-}
+DT_F16  = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
+DT_F32  = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
+DT_I32  = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
+DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
+
+@dataclass(frozen=True)
+class QuantizedDataType(DataType):
+    block_size: int
+    quantized_dtype: 'np.dtype[Any]'
+    ggml_type: gguf.GGMLQuantizationType
 
 
-NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
-    {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
+    def quantize(self, arr: NDArray) -> NDArray:
+        raise NotImplementedError(f'Quantization for {self.name} not implemented')
+
+    def elements_to_bytes(self, n_elements: int) -> int:
+        assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
+        return self.quantized_dtype.itemsize * (n_elements // self.block_size)
+
+@dataclass(frozen=True)
+class Q8_0QuantizedDataType(QuantizedDataType):
+    # Mini Q8_0 quantization in Python!
+    def quantize(self, arr: NDArray) -> NDArray:
+        assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
+        assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+        n_blocks = arr.size // self.block_size
+        blocks = arr.reshape((n_blocks, self.block_size))
+        # Much faster implementation of block quantization contributed by @Cebtenzzre
+        def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
+            d = abs(blocks).max(axis = 1) / np.float32(127)
+            with np.errstate(divide = 'ignore'):
+                qs = (blocks / d[:, None]).round()
+            qs[d == 0] = 0
+            yield from zip(d, qs)
+        return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype)
+
+DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
+    dtype = np.dtype(np.float32), valid_conversions = [],
+    ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32,
+    quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
+
+# Quantized types skipped here because they may also map to np.float32
+NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
+for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
+    if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
+        raise ValueError(f'Invalid duplicate data type {dt}')
+    NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
 
 
 SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
 SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
     'BF16': DT_BF16,
     'BF16': DT_BF16,
@@ -73,20 +115,22 @@ SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
 # TODO: rename to LLAMAFileType
 # TODO: rename to LLAMAFileType
 # TODO: move to `gguf.py`
 # TODO: move to `gguf.py`
 class GGMLFileType(enum.IntEnum):
 class GGMLFileType(enum.IntEnum):
-    AllF32    = 0
-    MostlyF16 = 1  # except 1d tensors
+    AllF32     = 0
+    MostlyF16  = 1  # except 1d tensors
+    MostlyQ8_0 = 7  # except 1d tensors
 
 
     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
-        if len(tensor.shape) == 1:
-            # 1D tensors are always F32.
-            return DT_F32
-        elif self == GGMLFileType.AllF32:
-            return DT_F32
-        elif self == GGMLFileType.MostlyF16:
-            return DT_F16
-        else:
+        dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
+        if dt is None:
             raise ValueError(self)
             raise ValueError(self)
+        # 1D tensors are always F32.
+        return dt if len(tensor.shape) > 1 else DT_F32
 
 
+GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
+    GGMLFileType.AllF32    : DT_F32,
+    GGMLFileType.MostlyF16 : DT_F16,
+    GGMLFileType.MostlyQ8_0: DT_Q8_0,
+}
 
 
 #
 #
 # hparams loading
 # hparams loading
@@ -415,7 +459,7 @@ class UnquantizedTensor(Tensor):
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
 
 
     def astype(self, data_type: DataType) -> Tensor:
     def astype(self, data_type: DataType) -> Tensor:
-        dtype = DATA_TYPE_TO_NUMPY[data_type]
+        dtype = data_type.dtype
         if self.data_type == DT_BF16:
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
         return UnquantizedTensor(self.ndarray.astype(dtype))
@@ -454,22 +498,6 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
 GGMLCompatibleTensor = Union[UnquantizedTensor]
 GGMLCompatibleTensor = Union[UnquantizedTensor]
 
 
 
 
-class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
-        self.base = base
-        self.n_head = n_head
-        self.data_type = self.base.data_type
-
-    def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
-
-    def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
-
-    def permute(self, n_head: int, n_head_kv: int) -> Tensor:
-        raise Exception("shouldn't permute twice")
-
-
 @dataclass
 @dataclass
 class LazyTensor:
 class LazyTensor:
     _load: Callable[[], Tensor]
     _load: Callable[[], Tensor]
@@ -479,7 +507,9 @@ class LazyTensor:
 
 
     def load(self) -> Tensor:
     def load(self) -> Tensor:
         ret = self._load()
         ret = self._load()
-        assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+        # Should be okay if it maps to the same numpy type?
+        assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
+                (self.data_type, ret.data_type, self.description)
         return ret
         return ret
 
 
     def astype(self, data_type: DataType) -> 'LazyTensor':
     def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -490,8 +520,8 @@ class LazyTensor:
         return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
         return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
 
 
     def validate_conversion_to(self, data_type: DataType) -> None:
     def validate_conversion_to(self, data_type: DataType) -> None:
-        if data_type == self.data_type:
-            return
+        if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
+            raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
 
 
 
 
 LazyModel = Dict[str, LazyTensor]
 LazyModel = Dict[str, LazyTensor]
@@ -617,9 +647,7 @@ class LazyUnpickler(pickle.Unpickler):
         info = self.zip_file.getinfo(filename)
         info = self.zip_file.getinfo(filename)
 
 
         def load(offset: int, elm_count: int) -> NDArray:
         def load(offset: int, elm_count: int) -> NDArray:
-            dtype = DATA_TYPE_TO_NUMPY.get(data_type)
-            if dtype is None:
-                raise Exception("tensor stored in unsupported format")
+            dtype = data_type.dtype
             fp = self.zip_file.open(info)
             fp = self.zip_file.open(info)
             fp.seek(offset * dtype.itemsize)
             fp.seek(offset * dtype.itemsize)
             size = elm_count * dtype.itemsize
             size = elm_count * dtype.itemsize
@@ -683,7 +711,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
 
 
     def convert(info: Dict[str, Any]) -> LazyTensor:
     def convert(info: Dict[str, Any]) -> LazyTensor:
         data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
         data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
-        numpy_dtype = DATA_TYPE_TO_NUMPY[data_type]
+        numpy_dtype = data_type.dtype
         shape: List[int] = info['shape']
         shape: List[int] = info['shape']
         begin, end = info['data_offsets']
         begin, end = info['data_offsets']
         assert 0 <= begin <= end <= len(byte_buf)
         assert 0 <= begin <= end <= len(byte_buf)
@@ -723,23 +751,35 @@ def lazy_load_file(path: Path) -> ModelPlus:
 In = TypeVar('In')
 In = TypeVar('In')
 Out = TypeVar('Out')
 Out = TypeVar('Out')
 
 
-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
     '''Parallel map, but with backpressure.  If the caller doesn't call `next`
     '''Parallel map, but with backpressure.  If the caller doesn't call `next`
     fast enough, this will stop calling `func` at some point rather than
     fast enough, this will stop calling `func` at some point rather than
     letting results pile up in memory.  Specifically, there is a max of one
     letting results pile up in memory.  Specifically, there is a max of one
     output value buffered per thread.'''
     output value buffered per thread.'''
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    if concurrency < 2:
+        yield from map(func, iterable)
+        # Not reached.
+    iterable = iter(iterable)
+    with factory(max_workers = max_workers) as executor:
         futures: List[concurrent.futures.Future[Out]] = []
         futures: List[concurrent.futures.Future[Out]] = []
-        items_rev = list(iterable)[::-1]
-        for i in range(min(concurrency, len(items_rev))):
-            futures.append(executor.submit(func, items_rev.pop()))
+        done = False
+        for _ in range(concurrency):
+            try:
+                futures.append(executor.submit(func, next(iterable)))
+            except StopIteration:
+                done = True
+                break
+
         while futures:
         while futures:
             result = futures.pop(0).result()
             result = futures.pop(0).result()
-            if items_rev:
-                futures.append(executor.submit(func, items_rev.pop()))
+            while not done and len(futures) < concurrency:
+                try:
+                    futures.append(executor.submit(func, next(iterable)))
+                except StopIteration:
+                    done = True
+                    break
             yield result
             yield result
 
 
-
 def check_vocab_size(params: Params, vocab: Vocab) -> None:
 def check_vocab_size(params: Params, vocab: Vocab) -> None:
     if params.n_vocab != vocab.vocab_size:
     if params.n_vocab != vocab.vocab_size:
         assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
         assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
@@ -804,12 +844,11 @@ class OutputFile:
         self.gguf.add_token_types(toktypes)
         self.gguf.add_token_types(toktypes)
 
 
     def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
     def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
-        n_elements = 1
-        for dim in tensor.shape:
-            n_elements *= dim
-        data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
-        data_nbytes = n_elements * data_type.itemsize
-        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
+        n_elements = int(np.prod(tensor.shape))
+        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
+        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)
 
 
     def write_meta(self) -> None:
     def write_meta(self) -> None:
         self.gguf.write_header_to_file()
         self.gguf.write_header_to_file()
@@ -835,7 +874,20 @@ class OutputFile:
         of.close()
         of.close()
 
 
     @staticmethod
     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
+        name, lazy_tensor = item
+        tensor = lazy_tensor.load().to_ggml()
+        return (lazy_tensor.data_type, tensor.ndarray)
+
+    @staticmethod
+    def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
+        dt, arr = item
+        if not isinstance(dt, QuantizedDataType):
+            return arr
+        return dt.quantize(arr)
+
+    @staticmethod
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
         check_vocab_size(params, vocab)
         check_vocab_size(params, vocab)
 
 
         of = OutputFile(fname_out)
         of = OutputFile(fname_out)
@@ -851,16 +903,19 @@ class OutputFile:
         of.write_meta()
         of.write_meta()
         of.write_tensor_info()
         of.write_tensor_info()
 
 
-        def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
-            name, lazy_tensor = item
-            return lazy_tensor.load().to_ggml().ndarray
-
         # tensor data
         # tensor data
-        ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
+        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
+        if ftype == GGMLFileType.MostlyQ8_0:
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
+        else:
+            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
+
+        start = time.time()
         for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
         for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+            elapsed = time.time() - start
             size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
             size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
             padi = len(str(len(model)))
             padi = len(str(len(model)))
-            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
+            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
             of.gguf.write_tensor_data(ndarray)
             of.gguf.write_tensor_data(ndarray)
 
 
         of.close()
         of.close()
@@ -872,6 +927,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
         return GGMLFileType.AllF32
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
     if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
         return GGMLFileType.MostlyF16
         return GGMLFileType.MostlyF16
+    if output_type_str == "q8_0":
+        return GGMLFileType.MostlyQ8_0
 
 
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
 
 
@@ -918,7 +975,7 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
             print(f"skipping tensor {name_new}")
             print(f"skipping tensor {name_new}")
             continue
             continue
         else:
         else:
-            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
             out[name_new] = lazy_tensor
             out[name_new] = lazy_tensor
 
 
     return out
     return out
@@ -1023,6 +1080,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
     namestr = {
     namestr = {
         GGMLFileType.AllF32:    "f32",
         GGMLFileType.AllF32:    "f32",
         GGMLFileType.MostlyF16: "f16",
         GGMLFileType.MostlyF16: "f16",
+        GGMLFileType.MostlyQ8_0:"q8_0",
     }[file_type]
     }[file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
     ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
     if ret in model_paths:
     if ret in model_paths:
@@ -1046,12 +1104,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--dump",        action="store_true",    help="don't convert, just show what's in the model")
     parser.add_argument("--dump",        action="store_true",    help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single", action="store_true",    help="don't convert, just show what's in a single model file")
     parser.add_argument("--dump-single", action="store_true",    help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only",  action="store_true",    help="extract only the vocab")
     parser.add_argument("--vocab-only",  action="store_true",    help="extract only the vocab")
-    parser.add_argument("--outtype",     choices=["f32", "f16"], help="output format (default: based on input)")
+    parser.add_argument("--outtype",     choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir",   type=Path,              help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--vocab-dir",   type=Path,              help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile",     type=Path,              help="path to write to; default: based on input")
     parser.add_argument("--outfile",     type=Path,              help="path to write to; default: based on input")
     parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     parser.add_argument("--vocabtype",   choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
     parser.add_argument("--vocabtype",   choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
     parser.add_argument("--ctx",         type=int,               help="model training context (default: based on input)")
     parser.add_argument("--ctx",         type=int,               help="model training context (default: based on input)")
+    parser.add_argument("--concurrency", type=int,               help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
     args = parser.parse_args(args_in)
     args = parser.parse_args(args_in)
 
 
     if args.dump_single:
     if args.dump_single:
@@ -1073,6 +1132,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = {
         params.ftype = {
             "f32": GGMLFileType.AllF32,
             "f32": GGMLFileType.AllF32,
             "f16": GGMLFileType.MostlyF16,
             "f16": GGMLFileType.MostlyF16,
+            "q8_0": GGMLFileType.MostlyQ8_0,
         }[args.outtype]
         }[args.outtype]
 
 
     print(f"params = {params}")
     print(f"params = {params}")
@@ -1104,7 +1164,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = ftype
         params.ftype = ftype
         print(f"Writing {outfile}, format {ftype}")
         print(f"Writing {outfile}, format {ftype}")
 
 
-        OutputFile.write_all(outfile, params, model, vocab)
+        OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
         print(f"Wrote {outfile}")
         print(f"Wrote {outfile}")