Selaa lähdekoodia

llm : add Falcon support (#2717)

* llama : refactor GGUF constants into static maps

* llama : check if model architecture is known

* llama : refactor llama_model_load_internal()

* gguf : add KV constant maps

* llm : read arch-specific KVs

* convert : add dummy scores + types

* falcon : load tensor data (CPU only)

* llama : fix loading progress bar

* llama : add arch member to llama_model

* falcon : CPU inference working

* falcon : support non-40B models

* falcon : minor

* llama : minor updates

ggml-ci

* convert-falcon-hf-to-gguf.py : fix special token mapping

* llama.cpp : llama default UNK token = id 0

* llama.cpp : fix bpe tokenizer

* llama.cpp : fix the fix of bpe tokenizer

* ggml : pass eps to ggml_norm

* metal : implement RoPE (mode = 2) + avoid ggml_repeat

* ggml : ggml_repeat always creates new tensor

* falcon : copy-paste self-attention from LLaMA

* metal : print extra compute pipeline info

* falcon : minor changes (still chasing the Metal problem)

* llama.cpp : fix linefeed token

* metal : fix GELU kernel numerical stability by using precise::tanh

* metal : temporary workaround for the concurrency optimization bug

* falcon : add CUDA offloading (#2739)

* llama : better model naming and size reporting

* llama : prep new tokenizer support

* llama : advanced BPE tokenizer based on ggllm.cpp imlpementation

* llama : remove oboslete comment

ggml-ci

* common : remove obsolete BPE API + disable test-tokenizer-1

* llama : revert BPE special-case in llama_byte_to_token()

* cuda : add TODOs for RoPE NeoX implementation

* llama : default special tokens based on vocab type

* perplexity : add log for start of tokenization

---------

Co-authored-by: klosax <131523366+klosax@users.noreply.github.com>
Co-authored-by: slaren <slarengh@gmail.com>
Georgi Gerganov 2 vuotta sitten
vanhempi
sitoutus
cf658ad
18 muutettua tiedostoa jossa 877 lisäystä ja 473 poistoa
  1. 0 32
      common/common.cpp
  2. 0 9
      common/common.h
  3. 25 30
      convert-falcon-hf-to-gguf.py
  4. 5 1
      convert.py
  5. 8 6
      examples/main/main.cpp
  6. 21 10
      examples/perplexity/perplexity.cpp
  7. 2 2
      ggml-alloc.c
  8. 1 1
      ggml-alloc.h
  9. 28 1
      ggml-cuda.cu
  10. 68 64
      ggml-metal.m
  11. 25 2
      ggml-metal.metal
  12. 15 15
      ggml.c
  13. 4 3
      ggml.h
  14. 13 13
      gguf.py
  15. 654 258
      llama.cpp
  16. 2 13
      llama.h
  17. 2 1
      tests/CMakeLists.txt
  18. 4 12
      tests/test-tokenizer-1.cpp

+ 0 - 32
common/common.cpp

@@ -744,35 +744,3 @@ std::string llama_token_to_str(const struct llama_context * ctx, llama_token tok
 
     return std::string(result.data(), result.size());
 }
-
-std::vector<llama_token> llama_tokenize_bpe(
-        struct llama_context * ctx,
-           const std::string & text,
-                        bool   add_bos) {
-    int n_tokens = text.length() + add_bos;
-    std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-    return result;
-}
-
-std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) {
-    std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-
-    return std::string(result.data(), result.size());
-}
-

+ 0 - 9
common/common.h

@@ -120,15 +120,6 @@ std::vector<llama_token> llama_tokenize(
            const std::string & text,
                         bool   add_bos);
 
-std::vector<llama_token> llama_tokenize_bpe(
-        struct llama_context * ctx,
-           const std::string & text,
-                        bool   add_bos);
-
 std::string llama_token_to_str(
         const struct llama_context * ctx,
                        llama_token   token);
-
-std::string llama_token_to_str_bpe(
-    const struct llama_context * ctx,
-                   llama_token   token);

+ 25 - 30
convert-falcon-hf-to-gguf.py

@@ -95,14 +95,17 @@ print("gguf: get model metadata")
 
 block_count = hparams["n_layer"]
 
-gguf_writer.add_name(last_dir)
+gguf_writer.add_name("Falcon")
 gguf_writer.add_context_length(2048) # not in config.json
 gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
 gguf_writer.add_embedding_length(hparams["hidden_size"])
 gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"])
 gguf_writer.add_block_count(block_count)
 gguf_writer.add_head_count(hparams["n_head"])
-if "n_head_kv" in hparams: gguf_writer.add_head_count_kv(hparams["n_head_kv"])
+if "n_head_kv" in hparams:
+    gguf_writer.add_head_count_kv(hparams["n_head_kv"])
+else:
+    gguf_writer.add_head_count_kv(1)
 gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
 
 # TOKENIZATION
@@ -110,6 +113,8 @@ gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
 print("gguf: get tokenizer metadata")
 
 tokens: List[str] = []
+scores: List[float] = []
+toktypes: List[int] = []
 merges: List[str] = []
 
 
@@ -153,41 +158,30 @@ if Path(dir_model + "/tokenizer.json").is_file():
             text = bytearray(pad_token)
 
         tokens.append(text)
+        scores.append(0.0)                      # dymmy
+        toktypes.append(gguf.TokenType.NORMAL)  # dummy
 
     gguf_writer.add_token_list(tokens)
+    gguf_writer.add_token_scores(scores)
+    gguf_writer.add_token_types(toktypes)
 
-    if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file():
-        print("gguf: get special token ids")
-
-        with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
-            tokenizer_config = json.load(f)
+print("gguf: get special token ids")
+# Look for special tokens in config.json
 
-        # find special token ids
+if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
+    gguf_writer.add_bos_token_id(hparams["bos_token_id"])
 
-        if "bos_token" in tokenizer_config:
-            for key in tokenizer_json["added_tokens"]:
-                if key["content"] == tokenizer_config["bos_token"]:
-                    gguf_writer.add_bos_token_id(key["id"])
+if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
+    gguf_writer.add_eos_token_id(hparams["eos_token_id"])
 
-        if "eos_token" in tokenizer_config:
-            for key in tokenizer_json["added_tokens"]:
-                if key["content"] == tokenizer_config["eos_token"]:
-                    gguf_writer.add_eos_token_id(key["id"])
+if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
+    gguf_writer.add_unk_token_id(hparams["unk_token_id"])
 
-        if "unk_token" in tokenizer_config:
-            for key in tokenizer_json["added_tokens"]:
-                if key["content"] == tokenizer_config["unk_token"]:
-                    gguf_writer.add_unk_token_id(key["id"])
+if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
+    gguf_writer.add_sep_token_id(hparams["sep_token_id"])
 
-        if "sep_token" in tokenizer_config:
-            for key in tokenizer_json["added_tokens"]:
-                if key["content"] == tokenizer_config["sep_token"]:
-                    gguf_writer.add_sep_token_id(key["id"])
-
-        if "pad_token" in tokenizer_config:
-            for key in tokenizer_json["added_tokens"]:
-                if key["content"] == tokenizer_config["pad_token"]:
-                    gguf_writer.add_pad_token_id(key["id"])
+if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
+    gguf_writer.add_pad_token_id(hparams["pad_token_id"])
 
 
 # TENSORS
@@ -195,8 +189,9 @@ if Path(dir_model + "/tokenizer.json").is_file():
 tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
 
 # params for qkv transform
-n_head = hparams["n_head"]
+n_head    = hparams["n_head"]
 n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
+
 head_dim = hparams["hidden_size"] // n_head
 
 # tensor info

+ 5 - 1
convert.py

@@ -733,7 +733,11 @@ class OutputFile:
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
 
     def add_meta_arch(self, params: Params) -> None:
-        self.gguf.add_name                ("LLaMA")
+        ver = None
+        if (params.n_ctx == 4096):
+            ver = "v2"
+
+        self.gguf.add_name                ("LLaMA" if ver == None else "LLaMA " + ver)
         self.gguf.add_context_length      (params.n_ctx)
         self.gguf.add_embedding_length    (params.n_embd)
         self.gguf.add_block_count         (params.n_layer)

+ 8 - 6
examples/main/main.cpp

@@ -43,7 +43,7 @@ static bool is_interacting = false;
 void sigint_handler(int signo) {
     if (signo == SIGINT) {
         if (!is_interacting) {
-            is_interacting=true;
+            is_interacting = true;
         } else {
             console::cleanup();
             printf("\n");
@@ -189,10 +189,12 @@ int main(int argc, char ** argv) {
         }
     }
 
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
     } else {
         embd_inp = session_tokens;
     }
@@ -208,9 +210,9 @@ int main(int argc, char ** argv) {
     int original_prompt_len = 0;
     if (ctx_guidance) {
         params.cfg_negative_prompt.insert(0, 1, ' ');
-        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
+        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, is_spm);
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
         original_prompt_len = original_inp.size();
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
     }
@@ -257,8 +259,8 @@ int main(int argc, char ** argv) {
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
-    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", is_spm);
+    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n",    false);
 
     // in instruct mode, we inject a prefix and a suffix to each input by the user
     if (params.instruct) {

+ 21 - 10
examples/perplexity/perplexity.cpp

@@ -28,7 +28,6 @@ std::vector<float> softmax(const std::vector<float>& logits) {
 }
 
 void perplexity_v2(llama_context * ctx, const gpt_params & params) {
-
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
@@ -38,7 +37,13 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
         fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
         return;
     }
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = is_spm;
+
+    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+
+    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
     const int calc_chunk = params.n_ctx;
 
@@ -86,7 +91,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
             const auto token_org = tokens[batch_start];
 
             // add BOS token for the first batch of each chunk
-            if (j == 0) {
+            if (add_bos && j == 0) {
                 tokens[batch_start] = llama_token_bos(ctx);
             }
 
@@ -136,7 +141,6 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
 }
 
 void perplexity(llama_context * ctx, const gpt_params & params) {
-
     if (params.ppl_stride > 0) {
         perplexity_v2(ctx, params);
         return;
@@ -146,7 +150,13 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = is_spm;
+
+    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+
+    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
     const int n_chunk_max = tokens.size() / params.n_ctx;
 
@@ -177,7 +187,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
             const auto token_org = tokens[batch_start];
 
             // add BOS token for the first batch of each chunk
-            if (j == 0) {
+            if (add_bos && j == 0) {
                 tokens[batch_start] = llama_token_bos(ctx);
             }
 
@@ -295,8 +305,10 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     size_t hs_task_count = prompt_lines.size()/6;
     fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
 
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+
     // This is needed as usual for LLaMA models
-    bool prepend_bos = true;
+    const bool add_bos = is_spm;
 
     // Number of tasks to use when computing the score
     if ( params.hellaswag_tasks < hs_task_count  ) {
@@ -352,14 +364,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     std::vector<float> tok_logits(n_vocab);
 
     for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
-
         // Tokenize the context to count tokens
-        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
+        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
         size_t context_size = context_embd.size();
 
         // Do the 1st ending
         // In this case we include the context when evaluating
-        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
+        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
         auto query_size = query_embd.size();
         //printf("First query: %d\n",(int)query_size);
 

+ 2 - 2
ggml-alloc.c

@@ -238,7 +238,7 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_t
     alloc->n_free_blocks++;
 }
 
-void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) {
+void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
     int pos = 0;
     for (int i = 0; i < n; i++) {
         if (list[i] != -1) {
@@ -547,7 +547,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
                         struct ggml_tensor * view_src = get_view_source(parent);
                         struct hash_node * view_src_hn = hash_get(ht, view_src);
                         view_src_hn->n_views -= 1;
-                        AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views);
+                        AT_PRINTF("view_src %s\n", view_src->name);
                         if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
                             ggml_allocator_free_tensor(alloc, view_src);
                         }

+ 1 - 1
ggml-alloc.h

@@ -12,7 +12,7 @@ GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
 
 // tell the allocator to parse nodes following the order described in the list
 // you should call this if your graph are optimized to execute out-of-order
-GGML_API void   ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n);
+GGML_API void   ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
 
 GGML_API void   ggml_allocr_free(struct ggml_allocr * alloc);
 GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);

+ 28 - 1
ggml-cuda.cu

@@ -3907,6 +3907,29 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
+// TODO: this implementation is wrong!
+//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
+//                                const float p_delta, const int p_delta_rows, const float theta_scale) {
+//    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+//
+//    if (col >= ncols) {
+//        return;
+//    }
+//
+//    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+//    const int i = row*ncols + col/2;
+//
+//    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+//    const float sin_theta = sinf(theta);
+//    const float cos_theta = cosf(theta);
+//
+//    const float x0 = x[i + 0];
+//    const float x1 = x[i + ncols/2];
+//
+//    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
+//    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+//}
+
 static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
@@ -5515,7 +5538,8 @@ inline void ggml_cuda_op_rope(
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
-    const bool is_glm = mode & 4;
+    const bool is_neox = mode & 2;
+    const bool is_glm  = mode & 4;
 
     // compute
     if (is_glm) {
@@ -5523,6 +5547,9 @@ inline void ggml_cuda_op_rope(
         const float id_p = min(p, n_ctx - 2.f);
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+    } else if (is_neox) {
+        GGML_ASSERT(false && "RoPE NeoX not implemented yet");
+#pragma message("TODO: implement RoPE NeoX for CUDA")
     } else {
         const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);

+ 68 - 64
ggml-metal.m

@@ -167,7 +167,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+        fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+                (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
             fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
@@ -218,12 +220,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    fprintf(stderr, "%s: hasUnifiedMemory             = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    fprintf(stderr, "%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    fprintf(stderr, "%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate              = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        fprintf(stderr, "%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate              = built-in GPU\n", __func__);
+        fprintf(stderr, "%s: maxTransferRate               = built-in GPU\n", __func__);
     }
 
     return ctx;
@@ -537,8 +539,8 @@ void ggml_metal_graph_compute(
 
             id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
-            const int node_start =                                  (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
 
             for (int ind = node_start; ind < node_end; ++ind) {
                 const int i = has_concur ? ctx->concur_list[ind] : ind;
@@ -744,32 +746,31 @@ void ggml_metal_graph_compute(
                                 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                                 ne00%32 == 0 &&
                                 ne11 > 1) {
-                                    switch (src0->type) {
-                                        case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
-                                        case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
-                                        case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
-                                        case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
-                                        case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
-                                        case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
-                                        case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
-                                        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
-                                        default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
-                                    }
-                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
-                                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
-                                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-                                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
-                                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
-                                    [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
-                                    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                switch (src0->type) {
+                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+                                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+                                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+                                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+                                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+                                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+                                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+                                    default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                                 }
-                            else {
+                                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
+                                [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
+                                [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4];
+                                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
+                                [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
+                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:8];
+                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:9];
+                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:10];
+                                [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
 
@@ -868,24 +869,24 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                                [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                                 }
                                 else if (src0t == GGML_TYPE_Q5_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q6_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -938,16 +939,17 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_NORM:
                         {
-                            const float eps = 1e-5f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
                             const int nth = 256;
 
                             [encoder setComputePipelineState:ctx->pipeline_norm];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
                             [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
@@ -990,7 +992,9 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
                             [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
                             [encoder setBytes:&m0  length:sizeof(    float) atIndex:18];
+
                             const int nth = 32;
+
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ROPE:
@@ -1005,8 +1009,8 @@ void ggml_metal_graph_compute(
                             memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
                             [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
@@ -1057,24 +1061,24 @@ void ggml_metal_graph_compute(
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;

+ 25 - 2
ggml-metal.metal

@@ -87,7 +87,12 @@ kernel void kernel_gelu(
     device       float * dst,
     uint tpig[[thread_position_in_grid]]) {
     float x = src0[tpig];
-    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
 kernel void kernel_soft_max(
@@ -571,7 +576,25 @@ kernel void kernel_rope(
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         }
     } else {
-        // TODO: implement
+        for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                const float cos_theta = cos(theta);
+                const float sin_theta = sin(theta);
+
+                theta *= theta_scale;
+
+                const int64_t i0 = ib*n_dims + ic/2;
+
+                device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                const float x0 = src[0];
+                const float x1 = src[n_dims/2];
+
+                dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            }
+        }
     }
 }
 

+ 15 - 15
ggml.c

@@ -3554,9 +3554,9 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
-static const float GELU_COEF_A    = 0.044715f;
-static const float GELU_QUICK_COEF    = -1.702f;
-static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+static const float GELU_COEF_A     = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
+static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -5555,10 +5555,6 @@ struct ggml_tensor * ggml_repeat(
         is_node = true;
     }
 
-    if (ggml_are_same_shape(a, b) && !is_node) {
-        return a;
-    }
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
 
     result->op   = GGML_OP_REPEAT;
@@ -5789,6 +5785,7 @@ struct ggml_tensor * ggml_silu_back(
 static struct ggml_tensor * ggml_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        float eps,
         bool inplace) {
     bool is_node = false;
 
@@ -5799,7 +5796,7 @@ static struct ggml_tensor * ggml_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    // TODO: maybe store epsilon here?
+    ggml_set_op_params(result, &eps, sizeof(eps));
 
     result->op   = GGML_OP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5810,14 +5807,16 @@ static struct ggml_tensor * ggml_norm_impl(
 
 struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_norm_impl(ctx, a, false);
+        struct ggml_tensor  * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_norm_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_norm_impl(ctx, a, true);
+        struct ggml_tensor  * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, true);
 }
 
 // ggml_rms_norm
@@ -10619,7 +10618,8 @@ static void ggml_compute_forward_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const float eps = 1e-5f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -12537,7 +12537,7 @@ static void ggml_compute_forward_rope_f32(
                         dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                     }
                 } else {
-                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // TODO: this might be wrong for ne0 != n_dims - need double check
                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
@@ -12666,7 +12666,7 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
                 } else {
-                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // TODO: this might be wrong for ne0 != n_dims - need double check
                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {

+ 4 - 3
ggml.h

@@ -909,14 +909,15 @@ extern "C" {
             struct ggml_tensor  * b);
 
     // normalize along rows
-    // TODO: eps is hardcoded to 1e-5 for now
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,

+ 13 - 13
gguf.py

@@ -30,12 +30,12 @@ KEY_GENERAL_SOURCE_HF_REPO       = "general.source.hugginface.repository"
 KEY_GENERAL_FILE_TYPE            = "general.file_type"
 
 # LLM
-KEY_LLM_CONTEXT_LENGTH        = "{arch}.context_length"
-KEY_LLM_EMBEDDING_LENGTH      = "{arch}.embedding_length"
-KEY_LLM_BLOCK_COUNT           = "{arch}.block_count"
-KEY_LLM_FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
-KEY_LLM_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
-KEY_LLM_TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
+KEY_CONTEXT_LENGTH        = "{arch}.context_length"
+KEY_EMBEDDING_LENGTH      = "{arch}.embedding_length"
+KEY_BLOCK_COUNT           = "{arch}.block_count"
+KEY_FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
+KEY_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
+KEY_TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
 
 # attention
 KEY_ATTENTION_HEAD_COUNT        = "{arch}.attention.head_count"
@@ -583,7 +583,7 @@ class GGUFWriter:
         self.add_string(KEY_GENERAL_AUTHOR, author)
 
     def add_tensor_data_layout(self, layout: str):
-        self.add_string(KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+        self.add_string(KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
 
     def add_url(self, url: str):
         self.add_string(KEY_GENERAL_URL, url)
@@ -613,27 +613,27 @@ class GGUFWriter:
 
     def add_context_length(self, length: int):
         self.add_uint32(
-            KEY_LLM_CONTEXT_LENGTH.format(arch=self.arch), length)
+            KEY_CONTEXT_LENGTH.format(arch=self.arch), length)
 
     def add_embedding_length(self, length: int):
         self.add_uint32(
-            KEY_LLM_EMBEDDING_LENGTH.format(arch=self.arch), length)
+            KEY_EMBEDDING_LENGTH.format(arch=self.arch), length)
 
     def add_block_count(self, length: int):
         self.add_uint32(
-            KEY_LLM_BLOCK_COUNT.format(arch=self.arch), length)
+            KEY_BLOCK_COUNT.format(arch=self.arch), length)
 
     def add_feed_forward_length(self, length: int):
         self.add_uint32(
-            KEY_LLM_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+            KEY_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 
     def add_parallel_residual(self, use: bool):
         self.add_bool(
-            KEY_LLM_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
+            KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
 
     def add_tensor_data_layout(self, layout: str):
         self.add_string(
-            KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+            KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
 
     def add_head_count(self, count: int):
         self.add_uint32(

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 654 - 258
llama.cpp


+ 2 - 13
llama.h

@@ -247,6 +247,8 @@ extern "C" {
     LLAMA_API int llama_n_ctx  (const struct llama_context * ctx);
     LLAMA_API int llama_n_embd (const struct llama_context * ctx);
 
+    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
+
     LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
     LLAMA_API int llama_model_n_ctx  (const struct llama_model * model);
     LLAMA_API int llama_model_n_embd (const struct llama_model * model);
@@ -368,13 +370,6 @@ extern "C" {
                              int   n_max_tokens,
                             bool   add_bos);
 
-    LLAMA_API int llama_tokenize_bpe(
-            struct llama_context * ctx,
-                      const char * text,
-                     llama_token * tokens,
-                             int   n_max_tokens,
-                            bool   add_bos);
-
     LLAMA_API int llama_tokenize_with_model(
         const struct llama_model * model,
                       const char * text,
@@ -390,12 +385,6 @@ extern "C" {
                                   char * buf,
                                   int    length);
 
-    LLAMA_API int llama_token_to_str_bpe(
-            const struct llama_context * ctx,
-                           llama_token   token,
-                                  char * buf,
-                                  int    length);
-
     LLAMA_API int llama_token_to_str_with_model(
               const struct llama_model * model,
                            llama_token   token,

+ 2 - 1
tests/CMakeLists.txt

@@ -28,7 +28,8 @@ llama_build_and_test_executable(test-sampling.cpp)
 llama_build_executable(test-tokenizer-0.cpp)
 llama_test_executable (test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
 llama_build_executable(test-tokenizer-1.cpp)
-llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
+# test-tokenizer-1 requires a BPE vocab. re-enable when we have one.
+#llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
 #llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
 llama_build_and_test_executable(test-grammar-parser.cpp)
 llama_build_and_test_executable(test-llama-grammar.cpp)

+ 4 - 12
tests/test-tokenizer-1.cpp

@@ -67,11 +67,13 @@ int main(int argc, char **argv) {
         }
     }
 
+    GGML_ASSERT(llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_BPE);
+
     const int n_vocab = llama_n_vocab(ctx);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string forward = llama_token_to_str_bpe(ctx, i);
-        std::vector<llama_token> tokens = llama_tokenize_bpe(ctx, forward, false);
+        std::string forward = llama_token_to_str(ctx, i);
+        std::vector<llama_token> tokens = llama_tokenize(ctx, forward, false);
         if (tokens.size() == 1) {
             if (i != tokens[0]) {
                 std::string backward = llama_token_to_str(ctx, tokens[0]);
@@ -79,16 +81,6 @@ int main(int argc, char **argv) {
                     __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
                 return 2;
             }
-        } else {
-            llama_token_type type = llama_token_get_type(ctx, i);
-            if (type == LLAMA_TOKEN_TYPE_UNKNOWN || type == LLAMA_TOKEN_TYPE_CONTROL || type == LLAMA_TOKEN_TYPE_BYTE) {
-                fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n",
-                    __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str());
-            } else {
-                fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n",
-                    __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str());
-                return 2;
-            }
         }
     }
 

Kaikkia tiedostoja ei voida näyttää, sillä liian monta tiedostoa muuttui tässä diffissä