Forráskód Böngészése

Refactor lora adapter support (#8332)

* lora: load to devide buft

* add patch tensor function

* correct tensor patch

* llama_lora_adapter_apply

* correct ggml_backend_tensor_copy

* add llm_build_mm

* fix auto merge

* update based on review comments

* add convert script

* no more transpose A

* add f16 convert

* add metadata check

* add sanity check

* fix ftype

* add requirements

* fix requirements

* fix outfile

* conversion: only allow selected models

* fix types

* cuda : do not use dmmv if the tensor does not have enough cols

* llama : lora fixes

* do not disable mmap with lora

Co-authored-by: slaren <slarengh@gmail.com>

* llm_build_lora_mm_id

* convert_lora : MoE LoRA conversion support

* convert_lora : prefer safetensors, similarly to convert_hf

* convert_hf : simplify modify_tensors for InternLM2

* convert_lora : lazy conversion

* llama : load and use alpha from LoRA adapters

* llama : use llm_build_lora_mm in most model graphs

* auto scale

* Revert "auto scale"

This reverts commit 42415a4874e0f963e4aca6796ea5dfb97cd17464.

* remove redundant params

* Apply suggestions from code review

Co-authored-by: slaren <slarengh@gmail.com>

* change kv metadata

* move add_type to __init__

* convert_hf : move add_type to main()

* convert_lora : use the GGUFWriter from Model instead of overwriting it

---------

Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Francis Couture-Harpin <git@compilade.net>
Xuan Son Nguyen 1 éve
szülő
commit
97bdd26eee

+ 3 - 10
common/common.cpp

@@ -685,7 +685,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     if (arg == "--lora") {
         CHECK_ARG
         params.lora_adapter.emplace_back(argv[i], 1.0f);
-        params.use_mmap = false;
         return true;
     }
     if (arg == "--lora-scaled") {
@@ -693,7 +692,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         const char* lora_adapter = argv[i];
         CHECK_ARG
         params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
-        params.use_mmap = false;
         return true;
     }
     if (arg == "--lora-base") {
@@ -2089,19 +2087,14 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
         const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
         float lora_scale = std::get<1>(params.lora_adapter[i]);
-        int err = llama_model_apply_lora_from_file(model,
-                                             lora_adapter.c_str(),
-                                             lora_scale,
-                                             ((i > 0) || params.lora_base.empty())
-                                                ? NULL
-                                                : params.lora_base.c_str(),
-                                             params.n_threads);
-        if (err != 0) {
+        auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
+        if (adapter == nullptr) {
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
             llama_free(lctx);
             llama_free_model(model);
             return std::make_tuple(nullptr, nullptr);
         }
+        llama_lora_adapter_set(lctx, adapter, lora_scale);
     }
 
     if (params.ignore_eos) {

+ 12 - 22
convert_hf_to_gguf.py

@@ -2264,13 +2264,6 @@ class InternLM2Model(Model):
 
         special_vocab.add_to_gguf(self.gguf_writer)
 
-    def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
-        if n_head_kv is not None and n_head != n_head_kv:
-            n_head = n_head_kv
-        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-                .swapaxes(1, 2)
-                .reshape(weights.shape))
-
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("InternLM2")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
@@ -2290,26 +2283,22 @@ class InternLM2Model(Model):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         num_heads = self.hparams["num_attention_heads"]
         num_kv_heads = self.hparams["num_key_value_heads"]
-        hidden_size = self.hparams["hidden_size"]
+        n_embd = self.hparams["hidden_size"]
         q_per_kv = num_heads // num_kv_heads
-        head_dim = hidden_size // num_heads
+        head_dim = n_embd // num_heads
         num_groups = num_heads // q_per_kv
 
-        qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
-
-        if re.match(qkv_pattern, name):
-            bid = re.findall(qkv_pattern, name)[0]
+        if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
             qkv = data_torch
-            # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
-            qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
-            q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
+
+            qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd))
+            q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1]
+
             # The model weights of q and k equire additional reshape.
-            # q = self._hf_permute_qk(rearrange(q, " o g n i ->  o (g n i)").T, num_heads, num_heads)
-            q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
-            # k = self._hf_permute_qk(rearrange(k, " o g n i ->  o (g n i)").T, num_heads, num_kv_heads)
-            k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
-            # v = rearrange(v, " o g n i ->  o (g n i)").T
-            v = v.reshape((v.shape[0], -1)).T
+            q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads)
+            k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads)
+            v = v.reshape((-1, v.shape[-1]))
+
             return [
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
@@ -3585,6 +3574,7 @@ def main() -> None:
                                      small_first_shard=args.no_tensor_first_split)
 
         logger.info("Set model parameters")
+        model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL)
         model_instance.set_gguf_parameters()
 
         logger.info("Set model tokenizer")

+ 374 - 0
convert_lora_to_gguf.py

@@ -0,0 +1,374 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+import logging
+import argparse
+import os
+import sys
+import json
+from math import prod
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
+
+import torch
+
+if TYPE_CHECKING:
+    from torch import Tensor
+
+if 'NO_LOCAL_GGUF' not in os.environ:
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
+import gguf
+
+# reuse model definitions from convert_hf_to_gguf.py
+from convert_hf_to_gguf import LazyTorchTensor, Model
+
+logger = logging.getLogger("lora-to-gguf")
+
+
+@dataclass
+class PartialLoraTensor:
+    A: Tensor | None = None
+    B: Tensor | None = None
+
+
+# magic to support tensor shape modifications and splitting
+class LoraTorchTensor:
+    _lora_A: Tensor  # (n_rank, row_size)
+    _lora_B: Tensor  # (col_size, n_rank)
+    _rank: int
+
+    def __init__(self, A: Tensor, B: Tensor):
+        assert len(A.shape) == len(B.shape)
+        assert A.shape[-2] == B.shape[-1]
+        if A.dtype != B.dtype:
+            A = A.to(torch.float32)
+            B = B.to(torch.float32)
+        self._lora_A = A
+        self._lora_B = B
+        self._rank = B.shape[-1]
+
+    def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
+        return (self._lora_A, self._lora_B)
+
+    def __getitem__(
+        self,
+        indices: (
+            SupportsIndex
+            | slice
+            | tuple[SupportsIndex | slice | Tensor, ...]  # TODO: add ellipsis in the type signature
+        ),
+    ) -> LoraTorchTensor:
+        shape = self.shape
+        if isinstance(indices, SupportsIndex):
+            if len(shape) > 2:
+                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
+            else:
+                raise NotImplementedError  # can't return a vector
+        elif isinstance(indices, slice):
+            if len(shape) > 2:
+                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
+            else:
+                return LoraTorchTensor(self._lora_A, self._lora_B[indices])
+        elif isinstance(indices, tuple):
+            assert len(indices) > 0
+            if indices[-1] is Ellipsis:
+                return self[indices[:-1]]
+            # expand ellipsis
+            indices = tuple(
+                u
+                for v in (
+                    (
+                        (slice(None, None) for _ in range(len(indices) - 1))
+                        if i is Ellipsis
+                        else (i,)
+                    )
+                    for i in indices
+                )
+                for u in v
+            )
+
+            if len(indices) < len(shape):
+                indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
+
+            # TODO: make sure this is correct
+            indices_A = (
+                *(
+                    (
+                        j.__index__() % self._lora_A.shape[i]
+                        if isinstance(j, SupportsIndex)
+                        else slice(None, None)
+                    )
+                    for i, j in enumerate(indices[:-2])
+                ),
+                slice(None, None),
+                indices[-1],
+            )
+            indices_B = indices[:-1]
+            return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
+        else:
+            raise NotImplementedError  # unknown indice type
+
+    @property
+    def dtype(self) -> torch.dtype:
+        assert self._lora_A.dtype == self._lora_B.dtype
+        return self._lora_A.dtype
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        assert len(self._lora_A.shape) == len(self._lora_B.shape)
+        return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
+
+    def size(self, dim=None):
+        assert dim is None
+        return self.shape
+
+    def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
+        if isinstance(shape[0], tuple):
+            new_shape: tuple[int, ...] = shape[0]
+        else:
+            new_shape = cast(tuple[int, ...], shape)
+        orig_shape = self.shape
+        if len(new_shape) < 2:
+            raise NotImplementedError  # can't become a vector
+
+        # expand -1 in the shape
+        if any(dim == -1 for dim in new_shape):
+            n_elems = prod(orig_shape)
+            n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
+            assert n_elems % n_new_elems == 0
+            new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
+
+        if new_shape[-1] != orig_shape[-1]:
+            raise NotImplementedError  # can't reshape the row size trivially
+
+        shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
+        shape_B = (*new_shape[:-1], self._rank)
+        return LoraTorchTensor(
+            self._lora_A.reshape(shape_A),
+            self._lora_B.reshape(shape_B),
+        )
+
+    def reshape_as(self, other: Tensor) -> LoraTorchTensor:
+        return self.reshape(*other.shape)
+
+    def view(self, *size: int) -> LoraTorchTensor:
+        return self.reshape(*size)
+
+    def permute(self, *dims: int) -> LoraTorchTensor:
+        shape = self.shape
+        dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
+        if dims[-1] == -1:
+            # TODO: support higher dimensional A shapes bigger than 1
+            assert all(dim == 1 for dim in self._lora_A.shape[:-2])
+            return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
+        if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
+            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
+        else:
+            # TODO: compose the above two
+            raise NotImplementedError
+
+    def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
+        shape = self.shape
+        dims = [i for i in range(len(shape))]
+        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
+        return self.permute(*dims)
+
+    def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
+        return self.transpose(axis0, axis1)
+
+    def to(self, *args, **kwargs):
+        return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
+
+    @classmethod
+    def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
+        del types  # unused
+
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.permute:
+            return type(args[0]).permute(*args, **kwargs)
+        elif func is torch.reshape:
+            return type(args[0]).reshape(*args, **kwargs)
+        elif func is torch.stack:
+            assert isinstance(args[0], Sequence)
+            dim = kwargs.get("dim", 0)
+            assert dim == 0
+            return LoraTorchTensor(
+                torch.stack([a._lora_A for a in args[0]], dim),
+                torch.stack([b._lora_B for b in args[0]], dim),
+            )
+        elif func is torch.cat:
+            assert isinstance(args[0], Sequence)
+            dim = kwargs.get("dim", 0)
+            assert dim == 0
+            if len(args[0][0].shape) > 2:
+                return LoraTorchTensor(
+                    torch.cat([a._lora_A for a in args[0]], dim),
+                    torch.cat([b._lora_B for b in args[0]], dim),
+                )
+            elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
+                return LoraTorchTensor(
+                    args[0][0]._lora_A,
+                    torch.cat([b._lora_B for b in args[0]], dim),
+                )
+            else:
+                raise NotImplementedError
+        else:
+            raise NotImplementedError
+
+
+def get_base_tensor_name(lora_tensor_name: str) -> str:
+    base_name = lora_tensor_name.replace("base_model.model.", "")
+    base_name = base_name.replace(".lora_A.weight", ".weight")
+    base_name = base_name.replace(".lora_B.weight", ".weight")
+    return base_name
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
+    parser.add_argument(
+        "--outfile", type=Path,
+        help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
+    )
+    parser.add_argument(
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
+        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
+    )
+    parser.add_argument(
+        "--bigendian", action="store_true",
+        help="model is executed on big endian machine",
+    )
+    parser.add_argument(
+        "--no-lazy", action="store_true",
+        help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
+    )
+    parser.add_argument(
+        "--verbose", action="store_true",
+        help="increase output verbosity",
+    )
+    parser.add_argument(
+        "--base", type=Path, required=True,
+        help="directory containing base model file",
+    )
+    parser.add_argument(
+        "lora_path", type=Path,
+        help="directory containing LoRA adapter file",
+    )
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+
+    ftype_map: dict[str, gguf.LlamaFileType] = {
+        "f32": gguf.LlamaFileType.ALL_F32,
+        "f16": gguf.LlamaFileType.MOSTLY_F16,
+        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
+        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
+        "auto": gguf.LlamaFileType.GUESSED,
+    }
+
+    ftype = ftype_map[args.outtype]
+
+    dir_base_model: Path = args.base
+    dir_lora: Path = args.lora_path
+    lora_config = dir_lora / "adapter_config.json"
+    input_model = dir_lora / "adapter_model.safetensors"
+
+    if args.outfile is not None:
+        fname_out = args.outfile
+    else:
+        # output in the same directory as the model by default
+        fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'
+
+    if os.path.exists(input_model):
+        # lazy import load_file only if lora is in safetensors format.
+        from safetensors.torch import load_file
+
+        lora_model = load_file(input_model, device="cpu")
+    else:
+        input_model = os.path.join(dir_lora, "adapter_model.bin")
+        lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
+
+    # load base model
+    logger.info(f"Loading base model: {dir_base_model.name}")
+    hparams = Model.load_hparams(dir_base_model)
+    with torch.inference_mode():
+        try:
+            model_class = Model.from_model_architecture(hparams["architectures"][0])
+        except NotImplementedError:
+            logger.error(f"Model {hparams['architectures'][0]} is not supported")
+            sys.exit(1)
+
+        class LoraModel(model_class):
+            model_arch = model_class.model_arch
+
+            def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+                tensor_map: dict[str, PartialLoraTensor] = {}
+
+                for name, tensor in lora_model.items():
+                    if self.lazy:
+                        tensor = LazyTorchTensor.from_eager(tensor)
+                    base_name = get_base_tensor_name(name)
+                    is_lora_a = ".lora_A.weight" in name
+                    is_lora_b = ".lora_B.weight" in name
+                    if not is_lora_a and not is_lora_b:
+                        if ".base_layer.weight" in name:
+                            continue
+                        logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
+                        sys.exit(1)
+
+                    if base_name in tensor_map:
+                        if is_lora_a:
+                            tensor_map[base_name].A = tensor
+                        else:
+                            tensor_map[base_name].B = tensor
+                    else:
+                        if is_lora_a:
+                            tensor_map[base_name] = PartialLoraTensor(A=tensor)
+                        else:
+                            tensor_map[base_name] = PartialLoraTensor(B=tensor)
+
+                for name, tensor in tensor_map.items():
+                    assert tensor.A is not None
+                    assert tensor.B is not None
+                    yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
+
+            def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+                dest = super().modify_tensors(data_torch, name, bid)
+                for dest_name, dest_data in dest:
+                    assert isinstance(dest_data, LoraTorchTensor)
+                    lora_a, lora_b = dest_data.get_lora_A_B()
+
+                    yield (dest_name + ".lora_a", lora_a)
+                    yield (dest_name + ".lora_b", lora_b)
+
+        model_instance = LoraModel(
+            dir_base_model,
+            ftype,
+            fname_out,
+            is_big_endian=args.bigendian,
+            use_temp_file=False,
+            eager=args.no_lazy,
+            model_name=None,
+        )
+
+        with open(lora_config, "r") as f:
+            lparams: dict[str, Any] = json.load(f)
+
+        alpha = lparams["lora_alpha"]
+
+        model_instance.gguf_writer.add_string(gguf.Keys.General.TYPE, gguf.GGUFType.ADAPTER)
+        model_instance.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
+        model_instance.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
+        model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+        logger.info("Exporting model...")
+        model_instance.write()
+        logger.info(f"Model successfully exported to {model_instance.fname_out}")

+ 2 - 1
ggml/src/ggml-cuda.cu

@@ -1876,7 +1876,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
     bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
+        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2
+        && src1->ne[1] == 1;
     bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;

+ 2 - 2
ggml/src/ggml.c

@@ -19478,7 +19478,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
 
     fprintf(fp, "digraph G {\n");
     fprintf(fp, "  newrank = true;\n");
-    fprintf(fp, "  rankdir = LR;\n");
+    fprintf(fp, "  rankdir = TB;\n");
 
     for (int i = 0; i < gb->n_nodes; i++) {
         struct ggml_tensor * node = gb->nodes[i];
@@ -19540,7 +19540,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
         }
 
         fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
-        if (ggml_nelements(node) < 5) {
+        if (ggml_nelements(node) < 5 && node->data != NULL) {
             fprintf(fp, " | (");
             for (int j = 0; j < ggml_nelements(node); j++) {
                 if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {

+ 10 - 0
gguf-py/gguf/constants.py

@@ -19,6 +19,7 @@ GGML_QUANT_VERSION     = 2  # GGML_QNT_VERSION from ggml.h
 
 class Keys:
     class General:
+        TYPE                 = "general.type"
         ARCHITECTURE         = "general.architecture"
         QUANTIZATION_VERSION = "general.quantization_version"
         ALIGNMENT            = "general.alignment"
@@ -120,11 +121,20 @@ class Keys:
         MIDDLE_ID            = "tokenizer.ggml.middle_token_id"
         EOT_ID               = "tokenizer.ggml.eot_token_id"
 
+    class Adapter:
+        TYPE       = "adapter.type"
+        LORA_ALPHA = "adapter.lora.alpha"
+
 #
 # recommended mapping of model tensor names for storage in gguf
 #
 
 
+class GGUFType:
+    MODEL   = "model"
+    ADAPTER = "adapter"
+
+
 class MODEL_ARCH(IntEnum):
     LLAMA        = auto()
     FALCON       = auto()

+ 3 - 0
gguf-py/gguf/gguf_writer.py

@@ -424,6 +424,9 @@ class GGUFWriter:
                 fout.close()
             self.fout = None
 
+    def add_type(self, type_name: str) -> None:
+        self.add_string(Keys.General.TYPE, type_name)
+
     def add_architecture(self) -> None:
         self.add_string(Keys.General.ARCHITECTURE, self.arch)
 

+ 1 - 1
gguf-py/gguf/quants.py

@@ -43,7 +43,7 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.
         osize *= dim
     out = np.empty(shape=osize, dtype=otype)
     # compute over groups of 16 rows (arbitrary, but seems good for performance)
-    n_groups = rows.shape[0] // 16
+    n_groups = (rows.shape[0] // 16) or 1
     np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
     return out.reshape(oshape)
 

+ 25 - 12
include/llama.h

@@ -411,6 +411,9 @@ extern "C" {
         const char * content;
     } llama_chat_message;
 
+    // lora adapter
+    struct llama_lora_adapter;
+
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
@@ -510,18 +513,28 @@ extern "C" {
             const char * fname_out,
             const llama_model_quantize_params * params);
 
-    // Apply a LoRA adapter to a loaded model
-    // path_base_model is the path to a higher quality model to use as a base for
-    // the layers modified by the adapter. Can be NULL to use the current loaded model.
-    // The model needs to be reloaded before applying a new adapter, otherwise the adapter
-    // will be applied on top of the previous one
-    // Returns 0 on success
-    LLAMA_API int32_t llama_model_apply_lora_from_file(
-            const struct llama_model * model,
-                          const char * path_lora,
-                               float   scale,
-                          const char * path_base_model,
-                             int32_t   n_threads);
+    // Load a LoRA adapter from file
+    // The loaded adapter will be associated to the given model, and will be free when the model is deleted
+    LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
+            struct llama_model * model,
+            const char * path_lora);
+
+    // Add a loaded LoRA adapter to given context
+    // This will not modify model's weight
+    LLAMA_API int32_t llama_lora_adapter_set(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter,
+            float scale);
+
+    // Remove a LoRA adapter from given context
+    // Return -1 if the adapter is not present in the context
+    LLAMA_API int32_t llama_lora_adapter_remove(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter);
+
+    // Manually free a LoRA adapter
+    // Note: loaded adapters will be free when the associated model is deleted
+    LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
 
     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
     // the currently loaded vector.

+ 1 - 0
requirements.txt

@@ -9,3 +9,4 @@
 -r ./requirements/requirements-convert_hf_to_gguf.txt
 -r ./requirements/requirements-convert_hf_to_gguf_update.txt
 -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt
+-r ./requirements/requirements-convert_lora_to_gguf.txt

+ 2 - 0
requirements/requirements-convert_lora_to_gguf.txt

@@ -0,0 +1,2 @@
+-r ./requirements-convert_hf_to_gguf.txt
+--extra-index-url https://download.pytorch.org/whl/cpu

A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 216 - 100
src/llama.cpp


Nem az összes módosított fájl került megjelenítésre, mert túl sok fájl változott