Преглед изворни кода

Inference support for T5 and FLAN-T5 model families (#5763)

* llama : add inference support and model types for T5 and FLAN-T5 model families

* llama : add new API functions to support encoder-decoder models: llama_encode(), llama_model_has_encoder(), llama_model_decoder_start_token()

* common, llama-cli, llama-batched : add support for encoder-decoder models

* convert-hf : handle shared token embeddings tensors in T5Model

* convert-hf : add support for SentencePiece BPE tokenizer in T5Model (for Pile-T5 models)

* convert-hf : add MT5ForConditionalGeneration and UMT5ForConditionalGeneration to architectures supported by T5Model

* convert : add t5 tokenizer tests, use "slow" HF tokenizer for t5

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
fairydreaming пре 1 година
родитељ
комит
807b0c49ff

+ 18 - 1
common/common.cpp

@@ -2070,7 +2070,24 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     if (params.warmup) {
         LOG("warming up the model with an empty run\n");
 
-        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
+        std::vector<llama_token> tmp;
+        llama_token bos = llama_token_bos(model);
+        llama_token eos = llama_token_eos(model);
+        // some models (e.g. T5) don't have a BOS token
+        if (bos != -1) {
+            tmp.push_back(bos);
+        }
+        tmp.push_back(eos);
+
+        if (llama_model_has_encoder(model)) {
+            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+            llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+            if (decoder_start_token_id == -1) {
+                decoder_start_token_id = bos;
+            }
+            tmp.clear();
+            tmp.push_back(decoder_start_token_id);
+        }
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);

+ 16 - 3
convert-hf-to-gguf-update.py

@@ -45,6 +45,7 @@ class TOKENIZER_TYPE(IntEnum):
     SPM = auto()
     BPE = auto()
     WPM = auto()
+    UGM = auto()
 
 
 # TODO: this string has to exercise as much pre-tokenizer functionality as possible
@@ -89,6 +90,7 @@ models = [
     {"name": "gemma",          "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
     {"name": "gemma-2",        "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
     {"name": "jais",           "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
+    {"name": "t5",             "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
 ]
 
 
@@ -110,9 +112,13 @@ def download_model(model):
     os.makedirs(f"models/tokenizers/{name}", exist_ok=True)
 
     files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
+
     if tokt == TOKENIZER_TYPE.SPM:
         files.append("tokenizer.model")
 
+    if tokt == TOKENIZER_TYPE.UGM:
+        files.append("spiece.model")
+
     for file in files:
         save_path = f"models/tokenizers/{name}/{file}"
         if os.path.isfile(save_path):
@@ -135,7 +141,7 @@ for model in models:
     name = model["name"]
     tokt = model["tokt"]
 
-    if tokt == TOKENIZER_TYPE.SPM:
+    if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
         continue
 
     # Skip if the tokenizer folder does not exist or there are other download issues previously
@@ -145,7 +151,10 @@ for model in models:
 
     # create the tokenizer
     try:
-        tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+        if name == "t5":
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
     except OSError as e:
         logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
         continue  # Skip to the next model if the tokenizer can't be loaded
@@ -266,6 +275,7 @@ tests = [
     "\n =",
     "' era",
     "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
+    "!!!!!!",
     "3",
     "33",
     "333",
@@ -304,7 +314,10 @@ for model in models:
 
     # create the tokenizer
     try:
-        tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+        if name == "t5":
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
     except OSError as e:
         logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
         continue  # Skip this model and continue with the next one in the loop

+ 36 - 10
convert-hf-to-gguf.py

@@ -2853,11 +2853,17 @@ class DeepseekV2Model(Model):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("T5ForConditionalGeneration")
 @Model.register("T5WithLMHeadModel")
+@Model.register("T5ForConditionalGeneration")
+@Model.register("MT5ForConditionalGeneration")
+@Model.register("UMT5ForConditionalGeneration")
 class T5Model(Model):
     model_arch = gguf.MODEL_ARCH.T5
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.shared_token_embeddings_found = False
+
     def set_vocab(self):
         # to avoid TypeError: Descriptors cannot be created directly
         # exception when importing sentencepiece_model_pb2
@@ -2865,17 +2871,29 @@ class T5Model(Model):
         from sentencepiece import SentencePieceProcessor
         from sentencepiece import sentencepiece_model_pb2 as model
 
-        tokenizer_path = self.dir_model / 'spiece.model'
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        # many older models use spiece.model tokenizer model filename
+        if not tokenizer_path.is_file():
+            tokenizer_path = self.dir_model / 'spiece.model'
 
         if not tokenizer_path.is_file():
             raise FileNotFoundError(f"File not found: {tokenizer_path}")
 
         sentencepiece_model = model.ModelProto()
         sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+
+        # some models like Pile-T5 family use BPE tokenizer instead of Unigram
+        if sentencepiece_model.trainer_spec.model_type == 2: # BPE
+            # assure the tokenizer model file name is correct
+            assert tokenizer_path.name == 'tokenizer.model'
+            return self._set_vocab_sentencepiece()
+        else:
+            assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
+
         add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
         remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
         precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
-        assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
 
         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))
@@ -2945,7 +2963,10 @@ class T5Model(Model):
 
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("T5")
-        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
+            logger.warning("Couldn't find context length in config.json, assuming default value of 512")
+            n_ctx = 512
+        self.gguf_writer.add_context_length(n_ctx)
         self.gguf_writer.add_embedding_length(self.hparams["d_model"])
         self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
         self.gguf_writer.add_block_count(self.hparams["num_layers"])
@@ -2961,12 +2982,17 @@ class T5Model(Model):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
-        # Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
-        # "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
-        # To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
-        if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
-            logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
-            return []
+        # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
+        # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
+        # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
+        # and decoder and ignore the remaining ones.
+        if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
+            if not self.shared_token_embeddings_found:
+                name = "shared.weight"
+                self.shared_token_embeddings_found = True
+            else:
+                logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
+                return []
 
         return [(self.map_tensor_name(name), data_torch)]
 

+ 27 - 7
examples/batched/batched.cpp

@@ -93,14 +93,34 @@ int main(int argc, char ** argv) {
 
     // create a llama_batch
     // we use this object to submit token data for decoding
-    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
+    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
+
+    std::vector<llama_seq_id> seq_ids(n_parallel, 0);
+    for (int32_t i = 0; i < n_parallel; ++i) {
+        seq_ids[i] = i;
+    }
 
     // evaluate the initial prompt
     for (size_t i = 0; i < tokens_list.size(); ++i) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+        llama_batch_add(batch, tokens_list[i], i, seq_ids, false);
     }
     GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
 
+    if (llama_model_has_encoder(model)) {
+        if (llama_encode(ctx, batch)) {
+            LOG_TEE("%s : failed to eval\n", __func__);
+            return 1;
+        }
+
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == -1) {
+            decoder_start_token_id = llama_token_bos(model);
+        }
+
+        llama_batch_clear(batch);
+        llama_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
+    }
+
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;
 
@@ -109,11 +129,11 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    // assign the system KV cache to all parallel sequences
-    // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
-    for (int32_t i = 1; i < n_parallel; ++i) {
-        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
-    }
+    //// assign the system KV cache to all parallel sequences
+    //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
+    //for (int32_t i = 1; i < n_parallel; ++i) {
+    //    llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+    //}
 
     if (n_parallel > 1) {
         LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);

+ 21 - 1
examples/main/main.cpp

@@ -255,7 +255,9 @@ int main(int argc, char ** argv) {
     }
 
     const bool add_bos = llama_should_add_bos_token(model);
-    GGML_ASSERT(llama_add_eos_token(model) != 1);
+    if (!llama_model_has_encoder(model)) {
+        GGML_ASSERT(llama_add_eos_token(model) != 1);
+    }
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
@@ -517,6 +519,24 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
+    if (llama_model_has_encoder(model)) {
+        int enc_input_size = embd_inp.size();
+        llama_token * enc_input_buf = embd_inp.data();
+
+        if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
+            LOG_TEE("%s : failed to eval\n", __func__);
+            return 1;
+        }
+
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == -1) {
+            decoder_start_token_id = llama_token_bos(model);
+        }
+
+        embd_inp.clear();
+        embd_inp.push_back(decoder_start_token_id);
+    }
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {

+ 15 - 0
include/llama.h

@@ -485,6 +485,13 @@ extern "C" {
     // Get a llama model tensor
     LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
 
+    // Returns true if the model contains an encoder that requires llama_encode() call
+    LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
+
+    // For encoder-decoder models, this function returns id of the token that must be provided
+    // to the decoder to start generating output sequence. For other models, it returns -1.
+    LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -770,6 +777,14 @@ extern "C" {
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);
 
+    // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
+    // Stores the encoder output internally for later use by the decoder cross-attention layers.
+    //   0 - success
+    // < 0 - error
+    LLAMA_API int32_t llama_encode(
+            struct llama_context * ctx,
+              struct llama_batch   batch);
+
     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)

+ 2 - 0
models/ggml-vocab-bert-bge.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-bert-bge.gguf.out

@@ -31,6 +31,7 @@
  1027
  1005 3690
  7592 1010 1061 1005 2035 999 2129 2024 2017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995
+ 999 999 999 999 999 999
  1017
  3943
  21211

+ 2 - 0
models/ggml-vocab-command-r.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-command-r.gguf.out

@@ -31,6 +31,7 @@
  206 1857
  14 4515
  28339 19 1770 14 1954 8 4070 1955 1933 80503 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372
+ 57178 10251
  26
  26 26
  26 26 26

+ 2 - 0
models/ggml-vocab-deepseek-coder.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-deepseek-coder.gguf.out

@@ -31,6 +31,7 @@
  185 405
  6 2895
  17535 11 320 6 435 0 1717 417 340 12394 233 210 3015 19100 608 9413 2668 16 18 16 19 16 20 16 1393 169 121 239
+ 15330 3023
  18
  18 18
  18 18 18

+ 2 - 0
models/ggml-vocab-deepseek-llm.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-deepseek-llm.gguf.out

@@ -31,6 +31,7 @@
  185 403
  6 2906
  17464 11 320 6 436 0 1724 418 340 33701 210 3025 19017 612 9407 2681 16 18 16 19 16 20 16 1398 68940 239
+ 15278 3033
  18
  18 18
  18 18 18

+ 2 - 0
models/ggml-vocab-falcon.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-falcon.gguf.out

@@ -31,6 +31,7 @@
  1212 40
  18 4932
  9856 23 291 18 436 12 1265 362 299 8196 207 204 42 50087 123 2727 20300 32022 133 234 17419 30137 28 7858 181 133 236
+ 51520
  30
  3138
  22287

+ 2 - 0
models/ggml-vocab-gpt-2.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-gpt-2.gguf.out

@@ -31,6 +31,7 @@
  198 796
  6 6980
  15496 11 331 6 439 0 1374 389 345 30325 223 5633 22755 239 46349 111 28839 101 18040 32432 98 43291 1485 1415 24309 25465 171 121 252
+ 13896 3228
  18
  2091
  20370

+ 2 - 0
models/ggml-vocab-llama-bpe.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-llama-bpe.gguf.out

@@ -31,6 +31,7 @@
  198 284
  6 11639
  9906 11 379 65948 0 2650 527 499 27623 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909
+ 17523 3001
  18
  1644
  8765

+ 2 - 0
models/ggml-vocab-llama-spm.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-llama-spm.gguf.out

@@ -31,6 +31,7 @@
  29871 13 353
  525 3152
  15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739
+ 1738 6824 21004
  29871 29941
  29871 29941 29941
  29871 29941 29941 29941

+ 2 - 0
models/ggml-vocab-mpt.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-mpt.gguf.out

@@ -31,6 +31,7 @@
  187 426
  8 8685
  12092 13 340 8 455 2 1359 403 368 49042 212 3736 15367 41197 13610 19934 41869 21275 1012 1047 18795 40120 20422 241
+ 18963 4672
  20
  1610
  20084

+ 2 - 0
models/ggml-vocab-phi-3.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-phi-3.gguf.out

@@ -31,6 +31,7 @@
  29871 13 353
  525 3152
  15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739
+ 1738 6824 21004
  29871 29941
  29871 29941 29941
  29871 29941 29941 29941

+ 2 - 0
models/ggml-vocab-qwen2.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-qwen2.gguf.out

@@ -31,6 +31,7 @@
  198 284
  6 11385
  9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216
+ 17085 2928
  18
  18 18
  18 18 18

+ 2 - 0
models/ggml-vocab-refact.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-refact.gguf.out

@@ -31,6 +31,7 @@
  203 280
  25 34666
  8279 30 533 25 464 19 4971 884 844 18458 228 1018 4982 13368 2909 9513 17827 35 37 35 38 35 39 35 11873 47838
+ 9163 3202
  37
  37 37
  37 37 37

+ 2 - 0
models/ggml-vocab-starcoder.gguf.inp

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

+ 1 - 0
models/ggml-vocab-starcoder.gguf.out

@@ -31,6 +31,7 @@
  222 299
  44 34719
  8302 49 553 44 483 38 4998 904 863 18445 247 1037 4995 13379 2924 9515 17823 54 56 54 57 54 58 54 11904 47892
+ 9221 3226
  56
  56 56
  56 56 56

Разлика између датотеке није приказан због своје велике величине
+ 745 - 6
src/llama.cpp


Неке датотеке нису приказане због велике количине промена