Bläddra i källkod

Make loading weights 10-100x faster

This is a breaking change that's going to give you three benefits:

1. Your inference commands should load 100x faster
2. You may be able to safely load models 2x larger
3. You can run many concurrent inference processes

This was accomplished by changing the file format so we can mmap()
weights directly into memory without having to read() or copy them
thereby ensuring the kernel can make its file cache pages directly
accessible to our inference processes; and secondly, that the file
cache pages are much less likely to get evicted (which would force
loads to hit disk) because they're no longer competing with memory
pages that were needlessly created by gigabytes of standard i/o.

The new file format supports single-file models like LLaMA 7b, and
it also supports multi-file models like LLaMA 13B. Our Python tool
now merges the foo.1, foo.2, etc. files back into a single file so
that the C++ code which maps it doesn't need to reshape data every
time. That's made llama.cpp so much simpler. Much of its load code
has now been deleted.

Furthermore, this change ensures that tensors are aligned properly
on a 32-byte boundary. That opens the door to seeing if we can get
additional performance gains on some microprocessors, by using ops
that require memory alignment.

Lastly note that both POSIX and the Windows platform are supported

Fixes #91
Justine Tunney 2 år sedan
förälder
incheckning
78ca9838ee
7 ändrade filer med 334 tillägg och 373 borttagningar
  1. 1 0
      .gitignore
  2. 5 0
      convert-ggml-to-pth.py
  3. 5 0
      convert-gptq-to-ggml.py
  4. 149 52
      convert-pth-to-ggml.py
  5. 173 320
      llama.cpp
  6. 1 1
      llama.h
  7. BIN
      models/ggml-vocab.bin

+ 1 - 0
.gitignore

@@ -22,6 +22,7 @@ models/*
 /result
 /result
 /perplexity
 /perplexity
 /embedding
 /embedding
+/Pipfile
 
 
 arm_neon.h
 arm_neon.h
 compile_commands.json
 compile_commands.json

+ 5 - 0
convert-ggml-to-pth.py

@@ -84,6 +84,11 @@ def read_variables(fin):
         shape = shape[::-1]
         shape = shape[::-1]
         name = fin.read(name_length).decode("utf-8")
         name = fin.read(name_length).decode("utf-8")
 
 
+        # ensure tensor data is aligned
+        tensor_data_offset = fin.tell()
+        tensor_data_offset = (tensor_data_offset + 31) & -32
+        fin.seek(tensor_data_offset)
+
         if ftype_cur == 2:
         if ftype_cur == 2:
             # 4-bit quantized weights
             # 4-bit quantized weights
             dtype = np.uint8
             dtype = np.uint8

+ 5 - 0
convert-gptq-to-ggml.py

@@ -72,6 +72,11 @@ def write_header(shape, dst_name, ftype_cur):
     fout.write(struct.pack("i" * len(shape), *shape[::-1]))
     fout.write(struct.pack("i" * len(shape), *shape[::-1]))
     fout.write(sname)
     fout.write(sname)
 
 
+    # ensure tensor data is aligned
+    tensor_data_offset = fout.tell()
+    tensor_data_offset = (tensor_data_offset + 31) & -32
+    fout.seek(tensor_data_offset)
+
 def convert_non_q4(src_name, dst_name):
 def convert_non_q4(src_name, dst_name):
     v = model[src_name]
     v = model[src_name]
     shape = v.shape
     shape = v.shape

+ 149 - 52
convert-pth-to-ggml.py

@@ -24,8 +24,57 @@ import torch
 
 
 from sentencepiece import SentencePieceProcessor
 from sentencepiece import SentencePieceProcessor
 
 
-def parse_args():
+QK = 32
+
+GGML_TYPE_Q4_0  = 0
+GGML_TYPE_Q4_1  = 1
+GGML_TYPE_I8    = 2
+GGML_TYPE_I16   = 3
+GGML_TYPE_I32   = 4
+GGML_TYPE_F16   = 5
+GGML_TYPE_F32   = 6
+
+WTYPES = {
+    0: GGML_TYPE_F32,
+    1: GGML_TYPE_F16,
+    2: GGML_TYPE_Q4_0,
+    3: GGML_TYPE_Q4_1,
+}
+
+GGML_BLCK_SIZE = {
+    GGML_TYPE_Q4_0:  QK,
+    GGML_TYPE_Q4_1:  QK,
+    GGML_TYPE_I8:    1,
+    GGML_TYPE_I16:   1,
+    GGML_TYPE_I32:   1,
+    GGML_TYPE_F16:   1,
+    GGML_TYPE_F32:   1,
+}
+
+GGML_TYPE_SIZE = {
+    GGML_TYPE_Q4_0: 4   + QK/2,
+    GGML_TYPE_Q4_1: 4*2 + QK/2,
+    GGML_TYPE_I8:   1,
+    GGML_TYPE_I16:  2,
+    GGML_TYPE_I32:  4,
+    GGML_TYPE_F16:  2,
+    GGML_TYPE_F32:  4,
+}
+
+def ggml_nelements(shape):
+    r = 1
+    for i in shape:
+        r *= i
+    return r
+
+def ggml_nbytes(shape, ftype):
+    x = ggml_nelements(shape)
+    t = WTYPES[ftype]
+    x *= GGML_TYPE_SIZE[t]
+    x //= GGML_BLCK_SIZE[t]
+    return x
 
 
+def parse_args():
     parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
     parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
     parser.add_argument('dir_model',  help='directory containing the model checkpoint')
     parser.add_argument('dir_model',  help='directory containing the model checkpoint')
     parser.add_argument('ftype',      help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
     parser.add_argument('ftype',      help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
@@ -33,7 +82,6 @@ def parse_args():
     return parser.parse_args()
     return parser.parse_args()
 
 
 def get_n_parts(dim):
 def get_n_parts(dim):
-
     mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
     mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
     n_parts = mappings.get(dim)
     n_parts = mappings.get(dim)
     if n_parts is None:
     if n_parts is None:
@@ -44,30 +92,24 @@ def get_n_parts(dim):
     return n_parts
     return n_parts
 
 
 def load_hparams_and_tokenizer(dir_model):
 def load_hparams_and_tokenizer(dir_model):
-
     # `dir_model` is something like `models/7B` or `models/7B/`.
     # `dir_model` is something like `models/7B` or `models/7B/`.
     # "tokenizer.model" is expected under model's parent dir.
     # "tokenizer.model" is expected under model's parent dir.
     # When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found.
     # When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found.
     # Let's use the model's parent dir directly.
     # Let's use the model's parent dir directly.
     model_parent_dir = os.path.dirname(os.path.normpath(dir_model))
     model_parent_dir = os.path.dirname(os.path.normpath(dir_model))
-
     fname_hparams = f"{dir_model}/params.json"
     fname_hparams = f"{dir_model}/params.json"
     fname_tokenizer = f"{model_parent_dir}/tokenizer.model"
     fname_tokenizer = f"{model_parent_dir}/tokenizer.model"
-
     with open(fname_hparams, "r") as f:
     with open(fname_hparams, "r") as f:
         hparams = json.load(f)
         hparams = json.load(f)
         print(hparams)
         print(hparams)
-
     tokenizer = SentencePieceProcessor(fname_tokenizer)
     tokenizer = SentencePieceProcessor(fname_tokenizer)
     hparams.update({"vocab_size": tokenizer.vocab_size()})
     hparams.update({"vocab_size": tokenizer.vocab_size()})
-
     return hparams, tokenizer
     return hparams, tokenizer
 
 
 def write_header(fout, hparams, ftype):
 def write_header(fout, hparams, ftype):
-
     keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
     keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
     values = [
     values = [
-        0x67676d66,  # magic: ggmf in hex
+        0x67676a74,  # magic: ggjt in hex
         1, # file version
         1, # file version
         *[hparams[key] for key in keys],
         *[hparams[key] for key in keys],
         hparams["dim"] // hparams["n_heads"],  # rot (obsolete)
         hparams["dim"] // hparams["n_heads"],  # rot (obsolete)
@@ -76,7 +118,6 @@ def write_header(fout, hparams, ftype):
     fout.write(struct.pack("i" * len(values), *values))
     fout.write(struct.pack("i" * len(values), *values))
 
 
 def write_tokens(fout, tokenizer):
 def write_tokens(fout, tokenizer):
-
     for i in range(tokenizer.vocab_size()):
     for i in range(tokenizer.vocab_size()):
         if tokenizer.is_unknown(i):
         if tokenizer.is_unknown(i):
             text = " \u2047 ".encode("utf-8")
             text = " \u2047 ".encode("utf-8")
@@ -95,85 +136,141 @@ def write_tokens(fout, tokenizer):
         fout.write(text)
         fout.write(text)
         fout.write(struct.pack("f", tokenizer.get_score(i)))
         fout.write(struct.pack("f", tokenizer.get_score(i)))
 
 
-def process_and_write_variables(fout, model, ftype):
-
+def process_and_write_variables(fout, model, ftype, part_id, n_parts):
     for name, datao in model.items():
     for name, datao in model.items():
-
         if name.endswith("freqs"):
         if name.endswith("freqs"):
             continue
             continue
 
 
-        shape = datao.shape
-
-        print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}")
-
+        # remove dimensions with a single element
         data = datao.numpy().squeeze()
         data = datao.numpy().squeeze()
-        n_dims = len(shape)
+        partshape = data.shape
+        n_dims = len(data.shape)
+        assert n_dims in (1, 2)
+
+        print(f"Processing variable: {name} with shape: {partshape} and type: {datao.dtype}")
 
 
-        # default type is fp16
+        # coerce single-dimensional tensors from float16 to float32
         ftype_cur = 1
         ftype_cur = 1
         if ftype == 0 or n_dims == 1:
         if ftype == 0 or n_dims == 1:
             print("  Converting to float32")
             print("  Converting to float32")
             data = data.astype(np.float32)
             data = data.astype(np.float32)
             ftype_cur = 0
             ftype_cur = 0
-
-        # header
+        blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]]
+        type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]]
+
+        # determine dimension along which multipart tensor is sharded
+        #
+        # split_dim 0 regex:
+        #   - output.*
+        #   - layers.*.attention.wq.weight
+        #   - layers.*.attention.wk.weight
+        #   - layers.*.attention.wv.weight
+        #   - layers.*.feed_forward.w1.weight
+        #   - layers.*.feed_forward.w3.weight
+        #
+        # split_dim 1 regex:
+        #   - tok_embeddings.*
+        #   - layers.*.attention.wo.weight
+        #   - layers.*.feed_forward.w2.weight
+        #
+        if n_dims > 1:
+            split_dim = 1
+            if "tok_embeddings" in name:
+                split_dim = 1
+            elif "layers" in name:
+                if "attention.wo.weight" in name:
+                    split_dim = 1
+                elif "feed_forward.w2.weight" in name:
+                    split_dim = 1
+                else:
+                    split_dim = 0
+            elif "output" in name:
+                split_dim = 0
+
+        # output tensor header
+        fullshape = list(partshape)
+        if n_dims > 1:
+            fullshape[split_dim] *= n_parts
         sname = name.encode('utf-8')
         sname = name.encode('utf-8')
-        fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
-        for dim in reversed(data.shape):
+        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
+        for dim in reversed(fullshape):
             fout.write(struct.pack("i", dim))
             fout.write(struct.pack("i", dim))
         fout.write(sname)
         fout.write(sname)
 
 
-        # data output to file
-        data.tofile(fout)
+        # ensure tensor data is aligned
+        tensor_data_offset = fout.tell()
+        while tensor_data_offset % QK != 0:
+            fout.write(struct.pack("B", 0))
+            tensor_data_offset += 1
+
+        # output unified mappable tensor data
+        if n_dims == 1 or n_parts == 1:
+            # copy tensor which we thankfully received in one piece
+            if part_id == 0:
+                data.tofile(fout)
+        elif split_dim == 0:
+            # reassemble multifile tensor containing some of the rows
+            rows_per_chunk = partshape[0]
+            current_row = part_id * rows_per_chunk
+            bytes_per_row = fullshape[1] // blck_size * type_size
+            offset = current_row * bytes_per_row
+            fout.seek(tensor_data_offset + offset)
+            data.tofile(fout)
+        elif split_dim == 1:
+            # reassemble multifile tensor containing some of the cols
+            cols_per_chunk = partshape[1]
+            current_col = part_id * cols_per_chunk
+            bytes_per_row = fullshape[1] // blck_size * type_size
+            offset_current_col = current_col // blck_size * type_size
+            for row in range(partshape[0]):
+                offset_row = row * bytes_per_row
+                offset = offset_row + offset_current_col
+                fout.seek(tensor_data_offset + offset)
+                data[row].tofile(fout)
+
+        # advance file position to next tensor
+        fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype_cur))
 
 
 def main():
 def main():
-
     args = parse_args()
     args = parse_args()
     dir_model = args.dir_model
     dir_model = args.dir_model
     ftype = args.ftype
     ftype = args.ftype
     ftype_str = ["f32", "f16"]
     ftype_str = ["f32", "f16"]
-
     hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
     hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
 
 
     print(args)
     print(args)
 
 
     # if only writing vocab to file
     # if only writing vocab to file
     if args.vocab_only:
     if args.vocab_only:
-
         fname_model = f"{dir_model}/consolidated.00.pth"
         fname_model = f"{dir_model}/consolidated.00.pth"
         fname_out = f"{dir_model}/ggml-vocab.bin"
         fname_out = f"{dir_model}/ggml-vocab.bin"
-
         print(f"Extracting only the vocab from '{fname_model}'\n")
         print(f"Extracting only the vocab from '{fname_model}'\n")
-
-
+        model = torch.load(fname_model, map_location="cpu")
         with open(fname_out, "wb") as fout:
         with open(fname_out, "wb") as fout:
             write_header(fout, hparams, ftype)
             write_header(fout, hparams, ftype)
             write_tokens(fout, tokenizer)
             write_tokens(fout, tokenizer)
-
-
+        del model
         print(f"Done. Output file: {fname_out}\n")
         print(f"Done. Output file: {fname_out}\n")
-
         return
         return
 
 
     n_parts = get_n_parts(hparams["dim"])
     n_parts = get_n_parts(hparams["dim"])
-
-    for p in range(n_parts):
-
-        print(f"Processing part {p+1} of {n_parts}\n")
-
-        fname_model = f"{dir_model}/consolidated.0{p}.pth"
-        fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
-
-        model = torch.load(fname_model, map_location="cpu")
-
-        with open(fname_out, "wb") as fout:
-            write_header(fout, hparams, ftype)
-            write_tokens(fout, tokenizer)
-            process_and_write_variables(fout, model, ftype)
-
-        del model
-
-        print(f"Done. Output file: {fname_out}, (part {p})\n")
+    fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin"
+
+    # we output a single file for ggml
+    with open(fname_out, "wb") as fout:
+        write_header(fout, hparams, ftype)
+        write_tokens(fout, tokenizer)
+        offset_of_tensors = fout.tell()
+        # the tensors we load could be split across multiple files
+        for part_id in range(n_parts):
+            fout.seek(offset_of_tensors)
+            print(f"Processing part {part_id+1} of {n_parts}\n")
+            fname_model = f"{dir_model}/consolidated.0{part_id}.pth"
+            model = torch.load(fname_model, map_location="cpu")
+            process_and_write_variables(fout, model, ftype, part_id, n_parts)
+            del model
+
+    print(f"Done. Output file: {fname_out}\n")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     main()
     main()

+ 173 - 320
llama.cpp

@@ -12,17 +12,19 @@
 #include <cassert>
 #include <cassert>
 #include <cstring>
 #include <cstring>
 
 
-// mmap
-#if defined (__unix__) || defined (__APPLE__)
-#   include <sys/mman.h>
-#   include <fcntl.h>
-#   include <unistd.h>
-#elif defined(_WIN32)
-#   define WIN32_LEAN_AND_MEAN
-#   include <Windows.h>
-//#include <Memoryapi.h>
+#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
+#define WIN32_LEAN_AND_MEAN
+#include <Windows.h>
+#else
+#include <sys/types.h>
+#include <sys/mman.h>
+#include <unistd.h>
+#include <fcntl.h>
 #endif
 #endif
 
 
+#define Min(X, Y) ((Y) > (X) ? (X) : (Y))
+#define Max(X, Y) ((Y) < (X) ? (X) : (Y))
+
 #define LLAMA_USE_SCRATCH
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
 
 
@@ -155,7 +157,7 @@ struct llama_model {
 
 
     // model memory mapped file
     // model memory mapped file
     void * mm_addr = NULL;
     void * mm_addr = NULL;
-    size_t mm_length = 0;
+    uint64_t mm_length = 0;
 
 
     // tensors
     // tensors
     int n_loaded;
     int n_loaded;
@@ -180,6 +182,7 @@ struct llama_context {
 
 
     int64_t t_load_us = 0;
     int64_t t_load_us = 0;
     int64_t t_start_us = 0;
     int64_t t_start_us = 0;
+    bool has_evaluated_once = false;
 
 
     int64_t t_sample_us = 0;
     int64_t t_sample_us = 0;
     int64_t t_eval_us   = 0;
     int64_t t_eval_us   = 0;
@@ -221,7 +224,7 @@ struct llama_context {
         }
         }
 
 
         if (buf_last >= 0) {
         if (buf_last >= 0) {
-            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
+            buf_max_size[buf_last] = Max(buf_max_size[buf_last], last_size);
         }
         }
 
 
         buf_last = i;
         buf_last = i;
@@ -304,59 +307,57 @@ struct llama_context_params llama_context_default_params() {
 // model loading
 // model loading
 //
 //
 
 
-static void mmap_file(const char* fname, void * &mm_addr, size_t &mm_length) {
-#if defined(MAP_FAILED)
-    // POSIX
-    int fd = open(fname, O_RDONLY);
-    mm_length = lseek(fd, 0, SEEK_END);
-    mm_addr = mmap(NULL, mm_length, PROT_READ, MAP_SHARED, fd, 0);
-    close(fd);
-    if (mm_addr == MAP_FAILED) {
-        perror("mmap failed");
-        mm_addr = NULL;
-        mm_length = 0;
-    }
-#elif defined(_WIN32)
-    mm_addr = NULL;
-    
-    HANDLE hFile = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
-    if (hFile == INVALID_HANDLE_VALUE) {
-        return;
-    }
-
-    // not really necessary
+static void *mmap_file(const char *fname, uint64_t *mm_length) {
+#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
+    HANDLE hFile = CreateFileA(fname,
+                               GENERIC_READ,
+                               FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
+                               NULL,
+                               OPEN_EXISTING,
+                               FILE_ATTRIBUTE_NORMAL | FILE_ATTRIBUTE_NOT_CONTENT_INDEXED,
+                               NULL);
+    if (hFile == INVALID_HANDLE_VALUE) return 0;
     LARGE_INTEGER fileSize;
     LARGE_INTEGER fileSize;
+    fileSize.QuadPart = -1;
     GetFileSizeEx(hFile, &fileSize);
     GetFileSizeEx(hFile, &fileSize);
-    mm_length = fileSize;
-
+    int64_t length = fileSize.QuadPart;
     HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
     HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
     CloseHandle(hFile);
     CloseHandle(hFile);
-
-    if (hMapping == NULL) {
-        return;
-    }
-
-    mm_addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
+    if (!hMapping) return 0;
+    void *addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
     CloseHandle(hMapping);
     CloseHandle(hMapping);
+    if (!addr) return 0;
 #else
 #else
-    mm_addr = NULL;
-    mm_length = 0;
-    (void)(fname); // suppress warnings
+    int fd = open(fname, O_RDONLY);
+    if (fd == -1) return 0;
+    int64_t length = lseek(fd, 0, SEEK_END);
+    void *addr = mmap(NULL, length, PROT_READ, MAP_SHARED, fd, 0);
+    close(fd);
+    if (addr == MAP_FAILED) return 0;
 #endif
 #endif
+    *mm_length = length;
+    return addr;
 }
 }
 
 
 static void munmap_file(void * addr, size_t length) {
 static void munmap_file(void * addr, size_t length) {
-#if defined(MAP_FAILED)
-    // POSIX
-    munmap(addr, length);
-#elif defined(_WIN32)
+#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
     UnmapViewOfFile(addr);
     UnmapViewOfFile(addr);
 #else
 #else
-    (void)(addr); // suppress warnings
-    (void)(length);
+    munmap(addr, length);
 #endif
 #endif
 }
 }
 
 
+static bool report_bad_magic(const char *path) {
+    fprintf(stderr,
+            "%s: invalid model file (bad magic)\n"
+            "you most likely need to regenerate your ggml files\n"
+            "the benefit is you'll get 10-100x faster load times\n"
+            "see https://github.com/ggerganov/llama.cpp/issues/91\n"
+            "use convert-pth-to-ggml.py on your llama model files\n",
+            path);
+    return false;
+}
+
 static bool llama_model_load(
 static bool llama_model_load(
         const std::string & fname,
         const std::string & fname,
         llama_context & lctx,
         llama_context & lctx,
@@ -368,23 +369,24 @@ static bool llama_model_load(
         void *progress_callback_user_data) {
         void *progress_callback_user_data) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
 
-    const int64_t t_start_us = ggml_time_us();
-
-    lctx.t_start_us = t_start_us;
-
-    // TODO: this could probably be smaller when using mmap
-    std::vector<char> f_buf(1024*1024);
+    lctx.t_start_us = ggml_time_us();
 
 
     auto & model = lctx.model;
     auto & model = lctx.model;
     auto & vocab = lctx.vocab;
     auto & vocab = lctx.vocab;
 
 
     auto fin = std::ifstream(fname, std::ios::binary);
     auto fin = std::ifstream(fname, std::ios::binary);
-    fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
     if (!fin) {
     if (!fin) {
         fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
         fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
         return false;
         return false;
     }
     }
 
 
+    std::vector<char> f_buf(1024*1024);
+    fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+
+    fin.seekg(0, fin.end);
+    const size_t file_size = fin.tellg();
+    fin.seekg(0);
+
     // verify magic
     // verify magic
     {
     {
         uint32_t magic;
         uint32_t magic;
@@ -395,8 +397,7 @@ static bool llama_model_load(
             return false;
             return false;
         }
         }
         if (magic != LLAMA_FILE_MAGIC) {
         if (magic != LLAMA_FILE_MAGIC) {
-            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
-            return false;
+            return report_bad_magic(fname.c_str());
         }
         }
 
 
         uint32_t format_version;
         uint32_t format_version;
@@ -519,54 +520,24 @@ static bool llama_model_load(
                 }
                 }
     }
     }
 
 
-    bool use_mmap = (n_parts == 1);
-
-    // try to memory map the model file
-    void * mm_addr = NULL;
-    if (use_mmap) {
-        mmap_file(fname.c_str(), model.mm_addr, model.mm_length);
-        if (model.mm_addr == NULL) {
-            use_mmap = false;
-        }
-        else {
-            mm_addr = model.mm_addr;
-        }
+    // map model into memory
+    char *mm_addr = NULL;
+    model.mm_addr = mmap_file(fname.c_str(), &model.mm_length);
+    if (model.mm_addr == NULL) {
+        fprintf(stderr, "%s: failed to mmap '%s'\n", __func__, fname.c_str());
+        return false;
     }
     }
+    mm_addr = (char *)model.mm_addr;
+    fprintf(stderr, "%s: ggml map size = %6.2f MB\n", __func__, model.mm_length/(1024.0*1024.0));
 
 
     auto & ctx = model.ctx;
     auto & ctx = model.ctx;
 
 
     size_t ctx_size = 0;
     size_t ctx_size = 0;
     {
     {
-        const auto & hparams = model.hparams;
-
-        const int n_embd  = hparams.n_embd;
+        const auto &hparams = model.hparams;
         const int n_layer = hparams.n_layer;
         const int n_layer = hparams.n_layer;
-        const int n_vocab = hparams.n_vocab;
-
-        if (!use_mmap) {
-            ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
-
-            ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
-
-            ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
-
-            ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
-
-            ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
-            ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
-            ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
-            ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
-
-            ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
-
-            ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
-            ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
-            ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
-        }
-
         ctx_size += (5 + 10*n_layer)*256; // object overhead
         ctx_size += (5 + 10*n_layer)*256; // object overhead
-
-        fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+        fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0);
     }
     }
 
 
     // print memory requirements
     // print memory requirements
@@ -576,6 +547,7 @@ static bool llama_model_load(
         // this is the total memory required to run the inference
         // this is the total memory required to run the inference
         const size_t mem_required =
         const size_t mem_required =
             ctx_size +
             ctx_size +
+            model.mm_length +
             MEM_REQ_SCRATCH0.at(model.type) +
             MEM_REQ_SCRATCH0.at(model.type) +
             MEM_REQ_SCRATCH1.at(model.type) +
             MEM_REQ_SCRATCH1.at(model.type) +
             MEM_REQ_EVAL.at    (model.type);
             MEM_REQ_EVAL.at    (model.type);
@@ -595,7 +567,7 @@ static bool llama_model_load(
         struct ggml_init_params params = {
         struct ggml_init_params params = {
             /*.mem_size   =*/ lctx.model.buf.size(),
             /*.mem_size   =*/ lctx.model.buf.size(),
             /*.mem_buffer =*/ lctx.model.buf.data(),
             /*.mem_buffer =*/ lctx.model.buf.data(),
-            /*.no_alloc   =*/ use_mmap,
+            /*.no_alloc   =*/ true,
         };
         };
 
 
         model.ctx = ggml_init(params);
         model.ctx = ggml_init(params);
@@ -658,241 +630,106 @@ static bool llama_model_load(
         }
         }
     }
     }
 
 
-    const size_t file_offset = fin.tellg();
-
-    fin.close();
-
     std::vector<uint8_t> tmp;
     std::vector<uint8_t> tmp;
 
 
     if (progress_callback) {
     if (progress_callback) {
         progress_callback(0.0, progress_callback_user_data);
         progress_callback(0.0, progress_callback_user_data);
     }
     }
 
 
-    for (int i = 0; i < n_parts; ++i) {
-        const int part_id = i;
-        //const int part_id = n_parts - i - 1;
-
-        std::string fname_part = fname;
-        if (i > 0) {
-            fname_part += "." + std::to_string(i);
-        }
+    fprintf(stderr, "%s: loading tensors from '%s'\n", __func__, fname.c_str());
 
 
-        fprintf(stderr, "%s: loading model part %d/%d from '%s'%s\n", __func__, i+1, n_parts, fname_part.c_str(), use_mmap ? " (memory mapped)" : "");
+    // load weights
+    {
+        size_t total_size = 0;
+        model.n_loaded = 0;
 
 
-        fin = std::ifstream(fname_part, std::ios::binary);
-        fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
 
 
-        fin.seekg(0, fin.end);
-        const size_t file_size = fin.tellg();
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
 
 
-        fin.seekg(file_offset);
+            if (fin.eof()) {
+                break;
+            }
 
 
-        // load weights
-        {
-            size_t total_size = 0;
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
 
 
-            model.n_loaded = 0;
+            std::string name(length, 0);
+            fin.read(&name[0], length);
 
 
-            fprintf(stderr, "%s: ", __func__);
+            if (model.tensors.find(name.data()) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                return false;
+            }
 
 
-            while (true) {
-                int32_t n_dims;
-                int32_t length;
-                int32_t ftype;
+            auto tensor = model.tensors[name.data()];
 
 
-                fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
-                fin.read(reinterpret_cast<char *>(&length), sizeof(length));
-                fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                return false;
+            }
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                        __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+            if (0) {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                fprintf(stderr, "%24s - [%5d, %5d], type = %6s\n", name.data(), ne[0], ne[1], ftype_str[ftype]);
+            }
 
 
-                if (fin.eof()) {
+            switch (ftype) {
+                case 0:  // f32
+                case 1:  // f16
                     break;
                     break;
-                }
-
-                int32_t nelements = 1;
-                int32_t ne[2] = { 1, 1 };
-                for (int i = 0; i < n_dims; ++i) {
-                    fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
-                    nelements *= ne[i];
-                }
-
-                std::string name(length, 0);
-                fin.read(&name[0], length);
-
-                if (model.tensors.find(name.data()) == model.tensors.end()) {
-                    fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                case 2:  // q4_0
+                case 3:  // q4_1
+                    assert(ne[0] % 64 == 0);
+                    break;
+                default:
+                    fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
                     return false;
                     return false;
-                }
-
-                // split_type = 0: split by columns
-                // split_type = 1: split by rows
-                int split_type = 0;
-
-                // split_type = 0:
-                // regex:
-                //   - tok_embeddings.*
-                //   - layers.*.attention.wo.weight
-                //   - layers.*.feed_forward.w2.weight
-
-                // split_type = 1:
-                // regex:
-                //   - output.*
-                //   - layers.*.attention.wq.weight
-                //   - layers.*.attention.wk.weight
-                //   - layers.*.attention.wv.weight
-                //   - layers.*.feed_forward.w1.weight
-                //   - layers.*.feed_forward.w3.weight
-                if (name.find("tok_embeddings") != std::string::npos) {
-                    split_type = 0;
-                } else if (name.find("layers") != std::string::npos) {
-                    if (name.find("attention.wo.weight") != std::string::npos) {
-                        split_type = 0;
-                    } else if (name.find("feed_forward.w2.weight") != std::string::npos) {
-                        split_type = 0;
-                    } else {
-                        split_type = 1;
-                    }
-                } else if (name.find("output") != std::string::npos) {
-                    split_type = 1;
-                }
-
-                auto tensor = model.tensors[name.data()];
-
-                if (n_dims == 1) {
-                    if (ggml_nelements(tensor) != nelements) {
-                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                        return false;
-                    }
-                } else {
-                    if (ggml_nelements(tensor)/n_parts != nelements) {
-                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                        return false;
-                    }
-                }
-
-                if (n_dims == 1) {
-                    if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
-                        fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
-                                __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
-                        return false;
-                    }
-                } else {
-                    if (split_type == 0) {
-                        if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
-                            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
-                                    __func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
-                            return false;
-                        }
-                    } else {
-                        if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
-                            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
-                                    __func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
-                            return false;
-                        }
-                    }
-                }
-
-                if (0) {
-                    static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
-                    fprintf(stderr, "%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
-                }
-
-                size_t bpe = 0;
-
-                switch (ftype) {
-                    case 0: bpe = ggml_type_size(GGML_TYPE_F32);  break;
-                    case 1: bpe = ggml_type_size(GGML_TYPE_F16);  break;
-                    case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
-                    case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
-                    default:
-                            {
-                                fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
-                                return false;
-                            }
-                };
-
-                if (n_dims == 1 || n_parts == 1) {
-                    if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
-                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                                __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
-                        return false;
-                    }
-
-                    if (part_id == 0) {
-                        if (mm_addr) {
-                            off_t offset = fin.tellg();
-                            tensor->data = (char *) mm_addr + offset;
-                            fin.seekg(ggml_nbytes(tensor), std::ios::cur);
-                        }
-                        else {
-                            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
-                        }
-                    } else {
-                        fin.seekg(ggml_nbytes(tensor), std::ios::cur);
-                    }
-
-                    total_size += ggml_nbytes(tensor);
-                } else {
-                    if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
-                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                                __func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
-                        return false;
-                    }
-
-                    if (split_type == 0) {
-                        const int np0 = ne[0];
-
-                        const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
-                        assert(row_size == tensor->nb[1]);
-
-                        for (int i1 = 0; i1 < ne[1]; ++i1) {
-                            const size_t offset_row = i1*row_size;
-                            const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
-                            fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
-                        }
-                    } else {
-                        const int np1 = ne[1];
-
-                        const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
-
-                        for (int i1 = 0; i1 < ne[1]; ++i1) {
-                            const size_t offset_row = (i1 + part_id*np1)*row_size;
-                            fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
-                        }
-                    }
-
-                    total_size += ggml_nbytes(tensor)/n_parts;
-                }
-
-                //fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
-                model.n_loaded++;
-
-                // progress
-                if (progress_callback) {
-                    float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset);
-                    float current_progress = (float(i) + current_file_progress) / float(n_parts);
-                    progress_callback(current_progress, progress_callback_user_data);
-                }
-                if (model.n_loaded % 8 == 0) {
-                    fprintf(stderr, ".");
-                    fflush(stderr);
-                }
-            }
-
-            fprintf(stderr, " done\n");
+            };
 
 
-            fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
-            if (model.n_loaded == 0) {
-                fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
-            } else if (model.n_loaded != (int) model.tensors.size()) {
-                fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
-                return false;
+            // load the tensor data into memory without copying or reading it
+            size_t offset = fin.tellg();
+            size_t tensor_data_size = ggml_nbytes(tensor);
+            offset = (offset + 31) & -32;
+            tensor->data = mm_addr + offset;
+            fin.seekg(offset + tensor_data_size);
+            total_size += tensor_data_size;
+            model.n_loaded++;
+
+            // progress
+            if (progress_callback) {
+                double current_progress = size_t(fin.tellg()) / double(file_size);
+                progress_callback(current_progress, progress_callback_user_data);
             }
             }
         }
         }
 
 
         fin.close();
         fin.close();
+
+        fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
+        if (model.n_loaded == 0) {
+            fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+        } else if (model.n_loaded != (int) model.tensors.size()) {
+            fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+            return false;
+        }
     }
     }
 
 
-    lctx.t_load_us = ggml_time_us() - t_start_us;
+    // loading time will be recalculate after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    lctx.t_load_us = ggml_time_us() - lctx.t_start_us;
 
 
     if (progress_callback) {
     if (progress_callback) {
         progress_callback(1.0, progress_callback_user_data);
         progress_callback(1.0, progress_callback_user_data);
@@ -1216,7 +1053,7 @@ struct llama_tokenizer {
         size_t offs = 0;
         size_t offs = 0;
         while (offs < text.size()) {
         while (offs < text.size()) {
             llama_sp_symbol sym;
             llama_sp_symbol sym;
-            size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
+            size_t char_len = Min(text.size() - offs, utf8_len(text[offs]));
             sym.text = text.c_str() + offs;
             sym.text = text.c_str() + offs;
             sym.n = char_len;
             sym.n = char_len;
             offs += char_len;
             offs += char_len;
@@ -1381,7 +1218,7 @@ static llama_vocab::id llama_sample_top_p_top_k(
 
 
     float maxl = -std::numeric_limits<float>::infinity();
     float maxl = -std::numeric_limits<float>::infinity();
     for (const auto & kv : logits_id) {
     for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
+        maxl = Max(maxl, kv.first);
     }
     }
 
 
     // compute probs for the top k tokens
     // compute probs for the top k tokens
@@ -1475,8 +1312,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
             return false;
             return false;
         }
         }
         if (magic != LLAMA_FILE_MAGIC) {
         if (magic != LLAMA_FILE_MAGIC) {
-            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
-            return false;
+            return report_bad_magic(fname_inp.c_str());
         }
         }
 
 
         fout.write((char *) &magic, sizeof(magic));
         fout.write((char *) &magic, sizeof(magic));
@@ -1542,8 +1378,8 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
             fout.write((char *) &len, sizeof(len));
             fout.write((char *) &len, sizeof(len));
 
 
             word.resize(len);
             word.resize(len);
-            finp.read ((char *) word.data(), len);
-            fout.write((char *) word.data(), len);
+            finp.read ((char *) &word[0], len);
+            fout.write((char *) &word[0], len);
 
 
             float score;
             float score;
             finp.read ((char *) &score, sizeof(score));
             finp.read ((char *) &score, sizeof(score));
@@ -1593,6 +1429,13 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
             std::string name(length, 0);
             std::string name(length, 0);
             finp.read (&name[0], length);
             finp.read (&name[0], length);
 
 
+            {
+                // ensure tensor data is aligned
+                uint64_t offset = finp.tellg();
+                offset = (offset + 31) & -32;
+                finp.seekg(offset);
+            }
+
             {
             {
                 static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
                 static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
                 printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
                 printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
@@ -1648,6 +1491,13 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
             }
             }
             fout.write(&name[0], length);
             fout.write(&name[0], length);
 
 
+            {
+                // ensure tensor data is aligned
+                uint64_t offset = fout.tellp();
+                offset = (offset + 31) & -32;
+                fout.seekp(offset);
+            }
+
             if (quantize) {
             if (quantize) {
                 printf("quantizing .. ");
                 printf("quantizing .. ");
                 work.resize(nelements); // for quantization
                 work.resize(nelements); // for quantization
@@ -1824,7 +1674,11 @@ int llama_eval(
         fprintf(stderr, "%s: failed to eval\n", __func__);
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
         return 1;
     }
     }
-
+    // get a more accurate load time, upon first eval
+    if (!ctx->has_evaluated_once) {
+        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
+        ctx->has_evaluated_once = true;
+    }
     return 0;
     return 0;
 }
 }
 
 
@@ -1917,9 +1771,9 @@ llama_token llama_sample_top_p_top_k(
 void llama_print_timings(struct llama_context * ctx) {
 void llama_print_timings(struct llama_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
     const int64_t t_end_us = ggml_time_us();
 
 
-    const int32_t n_sample = std::max(1, ctx->n_sample);
-    const int32_t n_eval   = std::max(1, ctx->n_eval);
-    const int32_t n_p_eval = std::max(1, ctx->n_p_eval);
+    const int32_t n_sample = Max(1, ctx->n_sample);
+    const int32_t n_eval   = Max(1, ctx->n_eval);
+    const int32_t n_p_eval = Max(1, ctx->n_p_eval);
 
 
     fprintf(stderr, "\n");
     fprintf(stderr, "\n");
     fprintf(stderr, "%s:        load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0);
     fprintf(stderr, "%s:        load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0);
@@ -1931,7 +1785,6 @@ void llama_print_timings(struct llama_context * ctx) {
 
 
 void llama_reset_timings(struct llama_context * ctx) {
 void llama_reset_timings(struct llama_context * ctx) {
     ctx->t_start_us = ggml_time_us();
     ctx->t_start_us = ggml_time_us();
-
     ctx->t_sample_us = ctx->n_sample = 0;
     ctx->t_sample_us = ctx->n_sample = 0;
     ctx->t_eval_us   = ctx->n_eval   = 0;
     ctx->t_eval_us   = ctx->n_eval   = 0;
     ctx->t_p_eval_us = ctx->n_p_eval = 0;
     ctx->t_p_eval_us = ctx->n_p_eval = 0;

+ 1 - 1
llama.h

@@ -20,7 +20,7 @@
 #endif
 #endif
 
 
 #define LLAMA_FILE_VERSION 1
 #define LLAMA_FILE_VERSION 1
-#define LLAMA_FILE_MAGIC 0x67676d66 // 'ggmf' in hex
+#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex
 #define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
 #define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
 
 
 #ifdef __cplusplus
 #ifdef __cplusplus

BIN
models/ggml-vocab.bin