Jelajahi Sumber

train : finetune LORA (#2632)

* fix track_max_mem in forward_batch_wo_cache_flash_attn_train

* remove unnecessary Adam(W) optimizer tensors.

reduces optimizer memory overhead from 7*modelsize to 2*modelsize.

additionally allows to optimize models with more than 2^31 parameters by replacing int with int64_t.

bumps training checkpoint file version, but old checkpoints can still be read.
new version with less tensors is saved.

* add gradient clipping to AdamW

* Fix reset of unused g->nodes and g->grads to NULL

* implement gradient checkpointing for training

reduces memory overhead from O(n_layer) to O(sqrt(n_layer))

as explained in readme of https://github.com/cybertronai/gradient-checkpointing

* remove unused compute buffer 3

* add and use function ggml_build_backward_expand to avoid stack overflows with large maximum number of nodes

GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);

* change AdamW decay parameter to work like the torch AdamW decay parameter

It is now relative to Adam learning rate `alpha*sched`.
Before that it was relative to `sched` only.

`alpha` being the maximum learning rate and `sched` being a scaling parameter in [0..1]

* change default AdamW weight decay parameter used in training to 0.1 as used in nanoGPT

* change default AdamW weight decay parameter defined in ggml to 0.0, making Adam default instead of AdamW

btw: the default weight decay parameter for torch.optim.AdamW is 0.01

* bug fixes for cross entropy loss

ggml_cross_entropy_loss: sums where not correctly added in workload of each thread
ggml_cross_entropy_loss_back: simplify backward process, reducing numerical issues

guard usage of exp f16 lookup in cross entropy by #define GGML_CROSS_ENTROPY_EXP_FP16

cross entropy loss is only used once during training, but it is quite sensitive to numerical errors introduced by exp-f16-lookup.
so exp-f16-lookup for cross entropy loss is disabled by default, trading better gradients for very slightly worse runtime performance.

* fix test-grad0 for cross_entropy_loss

the second argument to cross_entropy_loss must sum up to 1 for each row

* fix test-grad0 for soft_max

dont use only sum as aggregation, because sum of softmax is always 1 -> finite differences should not work
instead use sum(log(soft_max()*(1-eps)+eps)); use eps to avoid log(0)

* improve finite differences of test-grad0 by using double instead of float

* change cross_entropy_loss to output average over all rows

this helps keeping the loss and gradients in a sane range

* improve gradient checkpointing

sqrt(n_layers) is only the best checkpoint step when mem size of checkpoints and mem size of layers are equal.
since layers require more memory than the single-tensor-checkpoint we use, the optimal values are compute different:

```
  given: n, u, v
  objective: minimize(a*u+b*v) where a*b=n, a>0, b>0
  b=n/a
  minimize(a*u+v*n/a)
  diff(a*u+v*n/a, a) = u - (v*n/a)/a
  diff(a*u+v*n/a, a) == 0
  u - (v*n/a)/a == 0
  u == v*n/(a*a)
  u*a*a = v*n
  a*a = v*n/u
  a = sqrt(n*v/u)
```

this change results in more checkpoints, requiring less layers to store between checkpoints, overall improving memory usage.

* disable gradient checkpointing debug output

* llama : fix rope usage in train-text-from-scratch after ChatGLM change

* add more training parameters:

--enable-restart N         Only for Adam optimizer. Enable restarts of cos-decay
--disable-restart N        Only for Adam optimizer. Disable restarts of cos-decay
--opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero.
--opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero.
--opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero.
--adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero.
--adam-min-alpha N         Adam minimum learning rate alpha, usually 0.1 * alpha

* replace memcpy with reshape operation so that the graph is not cut at the input

this makes it possible to store other values into the input tensor and then simply recompute the graph without rebuilding it

* remove unused function argument from get_example_targets_batch

* measure and print total training time

* add optimization callback to ggml_opt_resume_g

this callback is called before each iteration with custom data and pointer to learning schedule parameter (only used in Adam(W)).

can be used for dynamic learning schedule and setting input data for batches before each iteration

* use optimization callback in training

allows dynamic learning schedule and different batch data for each iteration without relying on low n_iter and high n_examples parameters

reduces runtime by avoiding restart of optimization function and improves training convergence by providing a different batch for each iteration

* add minimum number of tensor dimensions to apply weight decay (default 2)

this allows to not apply weight decay to bias parameters

* rename training parameter cos-decay-alpha to cos-decay-min and clarify that adam-min-alpha also applies to warmup

* fix increase of model.train_samples and model.train_tokens

now that each optimizer iteration gets its own batch we need to multiply by number of opt iterations

* change sampling parameters for prediction after training to defaults of common.h

and clarify what is context for prediction and what are generated tokens

* tighten abs error bounds for cross_entropy_loss in test-grad0

* add conditional compilation of using F16 exp in flash attention

uncomment `// #define GGML_FLASH_ATTN_EXP_FP16` to enable usage of f16 exp in flash attention

* tighten abs error bounds for flash_attn in test-grad0

* tighten abs error bounds for sqrt in test-grad0

* remove out-commented vectorized code of opt_adam

the vectorized code might be bit faster for low number of parameters, but it had a big memory usage overhead

* ggml : update ggml_rms_norm_back with configurable eps

* llama training : fix ggml_rms_norm_back calls to pass configurable eps

* remove trailing whitespace

* add train function using automatic gradient checkpointing backward pass and allocator

* in train function replace add_inplace by regular add

because using add_inplace seems to result in different gradients

* don't use allocate hash_map on context

because the context has no_alloc=True when using memory allocator resulting in NULL data pointers

* correctly clone reshape and permute operations by also cloning tensor->nb values

* fix variable name and add missing type cast

* terminate recursive tensor cloning when reaching tensor without src tensors

* correctly clone view tensors by setting data pointers

without this the checkpointing would only work when being used together with memory allocator

* fix variable names

* swap arguments to commutative ops to be the same as in `forward_batch_wo_cache_flash_attn`

* add input tensors as checkpoints

so that recursive tensor cloning of gradient checkpointing terminates on input tensors

* fix variable name and add missing boolean negation

* make sure some tensors are not reallocated by inserting new temporary nodes depending on them:

output and parameter gradient tensors need to be available at the end of the graph execution

parameter gradient tensors also need to be available before the graph execution because they are set to zero before each optimizer iteration

checkpoint tensors are allocated all together to reduce memory allocator fragmentation

afterwards, in addition to the temporary nodes, we also need to reset the temporary leafs

* fix ASSERT to work with zero layers

* add training options whether to use allocator and/or unified training function

* integrate unified training function which may use memory allocator

the unified training function also supports arguments whether to use flash attention and/or gradient checkpointing

* format name of cloned tensors with " (clone)" suffix

* set names for tensors in unified train function for easier debugging

* allocate graph on context using ggml_new_graph

* remove handwritten training functions

* remove unused training parameters "use_scratch" and "use_unified"

* remove trailing whitespace

* remove unused train params: mem_compute1_gb & mem_compute2_gb

mem_compute_gb is used for compute when automatic memory allocator is not enabled, otherwise it can be very small to only hold the tensor definitions
mem_compute0_gb is used for automatic memory allocator (as long as measurement of max required size is not implemented)

* remove unused forward_batch function

* add debug asserts in ggml_allocr_alloc to some common pitfalls when using this function directly

* only use ggml_allocr_alloc when tensor has NULL data and is no view

* fix test when to create temporary backward graph

temporary backward graph is only necessary when using checkpointing

* fix memory "leak" in optimizers

each iteration a new cplan with new memory for work data was allocated.
now cplan creation only happens at the start of optimization, with each iteration reusing the cplan and its work data.

* reverse order of for loop in ggml_build_backward_expand to save memory when using gradient checkpointing and allocator

with this loop order gradient checkpointing with allocator on 16 layer model saves 13% memory; 2 layer memory it saves 2% memory.

the computation results are the same

* add API functions to access llama model tensors

* add stub example for finetuning, based on train-text-from-scratch

* move and remove code

* add API functions to access remaining model parameters:

mult, head and rot

* first draft for LORA finetune training

* remove const model and layer arguments in API functions for accessing model tensors

* bug fixes to make finetune compile

automatic allocator does not work yet

* add debug prints for training memory improvements

* fix names of lora tensors

* avoid stack overflow resulting from big ggml_cgraph

replace stack allocation and ggml_build_forward by ggml_new_graph in combination with ggml_build_forward_expand

* replace llama API functions to get model tensors by one function to get model tensor by name

LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);

* remove unused call to not existing llama_get_layer_from_model

* implement ggml_compute_forward_out_prod_q_f32

* remove trailing whitespace

* add lora finetune support on quantized base model tensors

* add ggml_add_cast API function

this function works like ggml_add, but accepts a data type for the resulting tensor.
only supported for quantized src0 input.

* use ggml_add_cast in finetuning

lora-applied weights will now have data type F32, which improves gradients when finetuning quantized base models

* bug fix: actually use result type passed to ggml_add_cast

* make sure base model tensors data cannot be used in viewable operations

memory allocator would try to make lora application inplace on base model tensors.
since those are memory mapped this will result in memory access violations

* fix bug in ggml_out_prod which resulted in wrong n_dims of result tensors

* avoid keeping in memory ALL of the gradients

The problem here stems from ggml_graph_reset. This function is called in the optimization function, before each graph computation, to reset the gradients to zero. This required a unique memory slot for each gradient: allocating memory from a previosly freed memory location might lead to non-zero input gradients.

During ggml_compute_backward the gradients are build stepwise by adding or substracting new values, starting from a OP_NONE tensor which needs to contain zero-values. This requires the graph reset.

To avoid this I now remember in ggml_build_backward_expand the original OP_NONE gradient tensors in a hash table, which is passed to ggml_compute_backward. There instead of using add (or sub or similar) I test whether the existing gradient to be changed is a zero-valued-tensor by looking up its existence in the hash table. When it is such a zero-tensor it will not be modified, but replaced by the value to be added, otherwise the regular add (not inplace, allocator will take care of this) will be used. This way none of those zero-tensor values will be necessary in the final backward graph and more importantly they won't need a unique memory slot, just to make them zero.

* remove trailing whitespace

* remove debug prints and function to compute tensor data hash

* improve optimization iteration prints

* adjust maximal values to support finetuning 3B models

* change default finetune params lora_r and lora_alpha to match the n_rank parameters of 4

* bug fix: make sure finetune input gradient is allocated at begin and kept until end

* remove unnecessary src tensor from ggml_get_rows_back

we don't need data of src[2] for computation, only to setup the correct output shape.
remove dependency on src[2], so that allocator can work more freely.

the computational graph is still completely determined, because the output shape is naturally included.
this is similar to how ggml_reshape does it.

* remove unnecessary src tensor from ggml_repeat & ggml_repeat_back

we don't need data of src[1] for computation, only to setup the correct output shape.
remove dependency on src[1], so that allocator can work more freely.

the computational graph is still completely determined, because the output shape is naturally included

* resolve todo

allocator will only make it inplace when they are of the same type

* mixing multiple LORA adapters is now possible

pass more than one '--lora FNAME' argument to apply more than one LORA.
use '--lora-scaled FNAME S' when you want to specify a user-defined scale for an adapter.

* add option to save finetune output every N iterations

* also save latest finetune output with ITERATION="LATEST" and print where files are saved

saving with LATEST makes it easier to resume training from the latest checkpoint
the string "LATEST" can be configured with command line option "--fn-latest STR"

* update checkpoint train stats before saving via "--save-every"

* add command line option `--rank-wo N` for rank of wo tensor

* update finetune README

* fix dump_non_result_info_yaml to output multiple lora adapters

* bug fix: replace GGML_TYPE_SIZE[t] by ggml_type_size(t)

* replace llama_n_mult by llama_n_ff

* finetune bug fixes to compile with merged in code from master

* remove prediction related code to reduce duplicated code with main

use main instead

* reduce large memory overhead in train-text-from-scratch

all gradients had to be pinned so that graph_reset works correctly.
this is no longer necessary with the changes to ggml_compute_backward introduced in this PR.

* add comment explaining why finetune checkpoints are allocated in one block

* make default value of float member a float literal

* handle rms_norm and rope parameters the same as in train-text-from-scratch

* remove unused code

* remove vocab related code as it is unnecessary

* add LLM_KV_TRAINING_TYPE to train-text-from-scratch checkpoints

so that they can be differentiated from lora finetune checkpoints

* add gguf constants and load/save functions from train-text-from-scratch

* add load & save lora finetune checkpoints via gguf

* add python script to convert old finetune checkpoint files to gguf

* remove old checkpoint save & load code

* remove code to print data checksums which was used to verify correctness of new gguf code

* omit tokenization when training is disabled, only save llama lora adapter

training can be disabled by passing '-n 0' to finetune

* remove trailing whitespace

* update README.md

* implement ggml_compute_forward_repeat_f16

* avoid stack overflow of large cgraphs in test-grad0

* add ggml API functions ggml_unravel_index, ggml_get_i32_nd and its analogs for set and for f32

ggml_get_i32_1d, ggml_set_i32_1d, ggml_get_f32_1d, ggml_set_f32_1d now support non-contiguous tensors.
in case of non-contiguous tensor, the 1d index is unraveled into a multi index using ggml_unravel_index to be passed to '_nd' function equivalent.

this fixes a bug in test-grad0 which happens due to ggml_build_backward not building purely contiguous tensors anymore

* increase test-grad0 context mem size to accommodate for bigger cgraph

* add sanity check to ggml_compute_backward, asserting the correct shape of gradients

* fix ggml_acc_or_set to return tensor of correct shape

* remove unused 'inplace' argument from ggml_compute_backward function

inplace operations to add gradients are no longer created by ggml_compute_backward
use allocator to automatically make inplace operations

* add missing argument 'int i0' to ggml_get_i32_nd & ggml_set_i32_nd header declarations

* fix error message in ggml_allocr_alloc to display actual max_avail

* fix check_gradient

ggml_build_backward_expand was previously replaced by ggml_build_backward, but the assignment of forward graph to backward graph missing

* use tensor->view_src instead of ggml_is_view and get_view_source

* move gradient checkpointing code into ggml, new API function:

// build gradient checkpointing backward graph gb for gf using provided checkpoints
// gb_tmp will contain original backward graph with rewritten backward process nodes,
// but without the second forward pass nodes.
GGML_API void ggml_build_backward_gradient_checkpointing(
        struct ggml_context   * ctx,
        struct ggml_cgraph    * gf,
        struct ggml_cgraph    * gb,
        struct ggml_cgraph    * gb_tmp,
        struct ggml_tensor  * * checkpoints,
        int                     n_checkpoints);

* replace custom data getters and setters by ggml functions

* train-text-from-scratch can train (full finetune) gguf models

just pass the gguf model via `--checkpoint-in FN`.
after this, to continue training, pass the generated checkpoint instead of the original gguf model.

tested with smaller models, bigger models may exceed available memory.
use (LORA) finetune for those.

* remove trailing whitespace

* add option to save train-text-from-scratch output every N iterations

* update README.md

* fix warnings

* fix warnings

* remove finetune option to disable allocator

the allocator should always be used.
by making sure that it is always used it gets easier to implement automatic memory requirements computation

* add tensor checkpoints only when gradient checkpointing is enabled

* initialize opt ggml context if none was provided

* add ggml-alloc API function 'ggml_allocr_max_size' to get max size of alloc

GGML_API size_t ggml_allocr_max_size(struct ggml_allocr * alloc);

* finetune: automatically allocate all memory and changes to command line options

remove '--n_examples N' parameter, as it no longer makes sense to call optimization process multiple times in a loop.
add '--only_write_lora' command line option: will skip tokenization and training, to only write a llama.cpp comptabile LORA adapter.
remove memory buffer related command line options.
improve iteration console output.

* add finetune to Makefile

* update README.md

* print time per iteration and estimate remaining time

* increase measured alloc size by tensor_alignment

ggml_allocr_reset will reduce the given size by up to tensor_alignment-1

* fix README.md

* add some more allocator debug prints

* bug fix, probably solves the 'ggml_allocr_alloc: not enough space in the buffer' issue

* revert last commit

"bug fix, probably solves the 'ggml_allocr_alloc: not enough space in the buffer' issue"

"alloc was freeing an externally allocated tensor, because it calculated the end of allocator memory as alloc->data + alloc->max_size instead of alloc->data + alloc->size."

This is intentional to reduce the risk of freeing external tensors when measuring. Unless max_size is not properly calculated, I don't see why this is an issue.

* remove unnecessary "0x" before "%p" output

* move measurement memory segment to upper region of the address space

* update README.md

* fix printf format warnings

* add missing gguf_free in load_checkpoint_lora_file

* load default rms_norm and rope parameters from base model

* add gradient accumulation

specify number accumulation steps with '--grad-acc N'.
this will simulate a bigger batch size of grad_acc*batch.

* fix tracking of train_samples and train_tokens

* build : fix compile warnings

* ggml : fix L-BFGS linesearch loop

* improve finetune time measurement

fix printf warnings on system where int64_t is (long int).
change time datatypes to double because values get big with long training times.
exclude file saving from time measurement.
converge faster to actual time per iteration by removing very small first duration before first iteration was performed.
fix bug in output of total training time, the reported value was 1000 times to small.

* specify default lora rank with '--lora-r N'

'--lora-r N' will specify default rank for all tensors
'--rank-wq N', etc. will override this default rank for specific tensor types.

* fix gradient accumulation bug where the same batch was used for each microstep

* fix gradient accumulation bug where the same batch was used for each microstep

* support grouped-query-attention in ggml_flash_attn and ggml_flash_attn_back

k and v can now be repeated in q along ne[2]

in forward pass just use modulo to compute k and v indices, like ik2 = iq2 % nek2.

in backard pass this won't work as easy, because multiple threads will compete to accumulate to the same k->grad[:,ik1,ik2,ik3] and v->grad[:,iv1,iv2,iv3].
so we change the parallelization over q rows to be over k rows. this ensures non-overlapping (ik2,ik3) across threads.
in each thread we then iterate over the number of repetitions of k/v in q to compute iq2 as iq2 = ik2 + irep*nek2.

since ne2 is not the same for q,k and v we also change how the gradients are concatenated into the result tensor.
additionally the offsets of gradq, gradk and gradv in the result tensor are now memory aligned.

we also simplify the compute_backward part of flash_attn to use ggml_reshape instead of switching over the number of dimensions.
this needs a small change to ggml_reshape, removing the assertion of second argument to be contiguous.
since only the shape (ne) of the second reshape argument is of relevance, its memory layout (nb) is irrelevant -> it can very well be non-contiguous.

change test-grad0 to also test for repeated k/v in q.

this changes the rng and now results in small gradient differences in softmax. these solely come from using f16 exp table lookup in forward softmax: when temporarily changing softmax to use actual exp function, the reported gradient differences go away. gradient differences coming solely from f16 table lookup are acceptable.
added a note to explain this.

* add llama API functions to get grouped-query-attention n_head parameter 'n_head_kv'.

* fix finetune to support grouped-query-attention (using flash-attention)

note: ggml changes to ggml_out_prod are necessary to support grouped-query-attention without flash-attention.

* support broadcastable a in out_prod(a, b) and backward pass of broadcasting mul_mat(a, b)

* test broadcasting mul_mat backward pass

* decouple random number generator of each operation test

when changing one test the rng of others tests is not influenced anymore

* add comment briefly describing what ggml_repeat_back does

* simplify broadcasting mul_mat backward using ggml_repeat_back

* add cgraph evaluation order member and corresponding enum type

this controls in which order ggml_build_forward visits source nodes.
by default the nodes are visited left to right, i.e. src[0] first.
in some cases it is beneficial for ggml-alloc to visit in a different order.
two possible orders are supported: left-to-right (src[0] first) and right-to-left (src[0] last).

* measure max compute size for each cgraph eval order and use best order

this can bring huge memory savings:
e.g. codellama-34b with n_ctx=64, n_batch=1 goes from 92927.8mb down to 4627.6 MB

* remove unused command line options

* add sample start patterns and options to force new or by default resume last shuffling

* update shuffle rng state on reshuffle

* exclude known zero values from computations in flash_attn_f32 & flash_attn_back_f32

* remove probably unnecessary exception type flags from stringstream

* pass correct max number of tokens to llama_tokenize

* account for possible leading whitespace that will be added by tokenizer
e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]

* use unrolled vec_mad in out_prod

y is vec_mad result vec.
x is vec_mad input vec.
v is vec_mad input scalar.

ggml_vec_mad_f32_unroll will internally loop over x and v with same y.

GGML_VEC_MAD_UNROLL is by default defined to 32.

This value is empirical optimized using performance test runs of out-prod in openllama-3b finetune with 256 context length and batch size 1. It gives 23% performance boost for out_prod.

Full measurements of out-prod runtime in ms:
	unroll_xv	unroll_yv
1	67014.643	87826.469
2	77117.552	89077.656
4	72091.311	109121.657
8	61077.543	88678.334
16	56914.67	79514.947
24	59024.595	84350.254
28	55952.446	83368.73
32	51476.658	85177.745
36	55973.792	84659.92
40	55139.616	93844.738
48	60736.392	93330.267
64	99856.878	116994.99

Second column is when unrollying yv instead of xv

* set lora_alpha to value of lora_r if it is not set via command line

otherwise only changing lora_r will change scaling of lora adapter used in prediction

* reshuffle original sample order instead of the previous shuffled order

otherwise resumed reshuffle will not result in same sample order

* block tiling for out-prod inspired by mul-mat

block sizes are empirically optimized

roughly doubles the flops of out-prod

* exclude some more known zero values from computations in flash_attn_f32 & flash_attn_back_f32

* add static keywords

* remove outcommented old code

* update train-text-from-scratch with tokenization, sample selection and shuffling from finetune

* remove lbfgs related train parameters

* move common train functions into common/train.[h|cpp]

* move train state into struct train_state

* move train data saving code into callback to unify code of opt_callback

train_params are still different in finetune and train-text-from-scratch, so it can't yet be moved to train.h|cpp

* move common train params into common/train

* move common opt_callback into common/train

* fix consume_common_train_arg

* save and load head_count_kv in lora checkpoints

* increase train_samples by used_samples instead of number of batches

on batch can contain more than one sample when option "fill_with_next_samples" is used

* fix usage of llama_tokenize

* remove static from process_escape since we need it exposed in header

* fix code formating of long function declarations

* fix condition in load_train_state_gguf

* use die("msg") instead of replace GGML_ASSERT(!"msg") or throw std::runtime_error("msg")

* fix saving and loading of training type

* remove terminating '\0' from tokenization

(llama_tokenize is now passed the string length instead of relying on terminating '\0')

* fix compile warnings

* fix compile warnings

* use new/delete for train_state instead of malloc/free

using malloc may result in seg faults when trying to assign string fields

* assert that sample_count > 0, avoiding division by zero

* fix frand to return value in interval [0,1)

* add train option "--sample-random-offsets"

Use samples beginning at random offsets.
The offset is only applied to the first sample in each batch context window.
Together with "--fill-with-next-samples" this may help for training endless text generation.

For example given a dataset containing samples "abcd", "ABCD", "0123".
With context size of 8 and options "--fill-with-next-samples", "--no-separate-with-eos", "--no-separate-with-bos",
the context windows of batches could only be filled with "abcdABCD", "ABCDabcd", "0123abcd", etc.

With "--sample-random-offsets" it can also be filled with "23abcdAB", "bcd0123A", etc.

* deduplicate code into function

* remove n_rot hparam, as it must always be hparam.n_embd_head()

* align code

* assert correct base model tensor shapes

* move some params from lora hparams into model hparams and load model params from gguf

this equalizes the model definition in finetune and text-from-scratch and removes the need for additional llama api functions to get model parameters

* remove now unnecessary llama API functions to get model params that where added by this PR

* train-text-from-scratch: automatically allocate model tensors, remove option '--mem-model N'

* train-text-from-scratch: automatically allocate opt context

* train-text-from-scratch: automatically allocate input tensors

* train-text-from-scratch: automatically allocate compute memory

* remove unused options and equalize train-text-from-scratch with finetune

* initialize opt->loss_after with zero

* add export-lora program

* remove trailing whitespace

* add export-lora build in Makefile

* remove unused struct tensor_info from export-lora

* add export-lora build dependency to llama

because it depends on common, which depends on llama

* update finetune README.md

* cancel optimization when specified number of epochs is completed

* improve handling of export-lora arguments

print errors and warnings when files could not be read or created

* Fix export-lora.cpp "not enough space in the context's memory pool" (#1)

* Fix export-lora.cpp "not enough space in the context's memory pool"

Without this patch, export-lora would sometimes error with "not enough space in the context's memory pool (needed 656784, available 656800)".

* increase required context size by 5*GGML_MEM_ALIGN instead of plain 16

---------

Co-authored-by: xaedes <xaedes@gmail.com>

* improve handling of not yet supported tensor types

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: meatbag-18a <145869052+meatbag-18a@users.noreply.github.com>
xaedes 2 tahun lalu
induk
melakukan
0e76a8992c

+ 2 - 0
.gitignore

@@ -52,6 +52,8 @@ models-mnt
 /server
 /simple
 /batched
+/export-lora
+/finetune
 /speculative
 /parallel
 /train-text-from-scratch

+ 12 - 3
Makefile

@@ -1,5 +1,5 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel tests/test-c.o
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel finetune export-lora tests/test-c.o
 
 # Binaries only useful for tests
 TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama
@@ -500,6 +500,9 @@ console.o: common/console.cpp common/console.h
 grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
+train.o: common/train.cpp common/train.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 libllama.so: llama.o ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
 
@@ -550,7 +553,7 @@ embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-te
 gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
-train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o common.o $(OBJS)
+train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o common.o train.o $(OBJS)
 	$(CXX) $(TTFS_CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
 convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp ggml.o llama.o $(OBJS)
@@ -559,12 +562,18 @@ convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggm
 llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
-baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
+baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o train.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
 beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
+finetune: examples/finetune/finetune.cpp build-info.h ggml.o llama.o common.o train.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
+export-lora: examples/export-lora/export-lora.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 

+ 2 - 0
common/CMakeLists.txt

@@ -9,6 +9,8 @@ add_library(${TARGET} OBJECT
     console.cpp
     grammar-parser.h
     grammar-parser.cpp
+    train.h
+    train.cpp
     )
 
 if (BUILD_SHARED_LIBS)

+ 37 - 6
common/common.cpp

@@ -78,7 +78,7 @@ int32_t get_num_physical_cores() {
     return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
 }
 
-static void process_escapes(std::string& input) {
+void process_escapes(std::string& input) {
     std::size_t input_len = input.length();
     std::size_t output_idx = 0;
 
@@ -352,7 +352,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 invalid_param = true;
                 break;
             }
-            params.lora_adapter = argv[i];
+            params.lora_adapter.push_back({argv[i], 1.0f});
+            params.use_mmap = false;
+        } else if (arg == "--lora-scaled") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            const char * lora_adapter = argv[i];
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])});
             params.use_mmap = false;
         } else if (arg == "--lora-base") {
             if (++i >= argc) {
@@ -703,6 +715,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
+    printf("  --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
     printf("  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     printf("  -m FNAME, --model FNAME\n");
     printf("                        model path (default: %s)\n", params.model.c_str());
@@ -776,10 +789,15 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         return std::make_tuple(nullptr, nullptr);
     }
 
-    if (!params.lora_adapter.empty()) {
+    for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
+        const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
+        float lora_scale = std::get<1>(params.lora_adapter[i]);
         int err = llama_model_apply_lora_from_file(model,
-                                             params.lora_adapter.c_str(),
-                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             lora_adapter.c_str(),
+                                             lora_scale,
+                                             ((i > 0) || params.lora_base.empty())
+                                                ? NULL
+                                                : params.lora_base.c_str(),
                                              params.n_threads);
         if (err != 0) {
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
@@ -1225,7 +1243,20 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
         fprintf(stream, "  %d: %f", lb.first, lb.second);
     }
 
-    fprintf(stream, "lora: %s\n", params.lora_adapter.c_str());
+    fprintf(stream, "lora:\n");
+    for (std::tuple<std::string, float> la : params.lora_adapter) {
+        if (std::get<1>(la) != 1.0f) {
+            continue;
+        }
+        fprintf(stream, "  - %s\n", std::get<0>(la).c_str());
+    }
+    fprintf(stream, "lora_scaled:\n");
+    for (std::tuple<std::string, float> la : params.lora_adapter) {
+        if (std::get<1>(la) == 1.0f) {
+            continue;
+        }
+        fprintf(stream, "  - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
+    }
     fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
     fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false");
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);

+ 4 - 2
common/common.h

@@ -85,8 +85,8 @@ struct gpt_params {
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
     std::string logdir            = "";  // directory in which to save YAML log files
 
-    std::string lora_adapter = "";  // lora adapter path
-    std::string lora_base    = "";  // base model path for the lora adapter
+    std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
+    std::string lora_base  = "";                              // base model path for the lora adapter
 
     int  ppl_stride        = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
     int  ppl_output_type   = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
@@ -128,6 +128,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
 
 std::string gpt_random_prompt(std::mt19937 & rng);
 
+void process_escapes(std::string& input);
+
 //
 // Model utils
 //

+ 1496 - 0
common/train.cpp

@@ -0,0 +1,1496 @@
+#include "train.h"
+#include "common.h"
+
+#include <random>
+#include <sstream>
+#include <functional>
+
+struct random_normal_distribution {
+    std::mt19937 gen;
+    std::normal_distribution<float> rd;
+    float min;
+    float max;
+};
+
+struct random_uniform_distribution {
+    std::mt19937 gen;
+    std::uniform_real_distribution<float> rd;
+};
+
+struct train_state  * init_train_state() {
+    struct train_state * state = new struct train_state;
+    state->train_its     = 0;
+    state->train_samples = 0;
+    state->train_tokens  = 0;
+    state->train_epochs  = 0;
+    state->shuffle_samples_hash  = 0;
+    state->shuffle_sample_count  = 0;
+    state->shuffle_next_sample   = 0;
+    state->shuffle_rng_state_current = "";
+    state->shuffle_rng_state_next    = "";
+
+    state->opt = new struct ggml_opt_context;
+    state->opt->ctx = NULL;
+    state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    state->opt->loss_after = 0.0f;
+
+    return state;
+}
+
+void free_train_state(struct train_state  * state) {
+    delete state->opt;
+    delete state;
+}
+
+struct random_normal_distribution * init_random_normal_distribution(
+    int seed, float mean, float std, float min, float max
+) {
+    struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution));
+    rnd->gen = std::mt19937(seed);
+    rnd->rd = std::normal_distribution<float>{mean, std};
+    rnd->min = min;
+    rnd->max = max;
+    return rnd;
+}
+
+struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max) {
+    struct random_uniform_distribution * rnd = (struct random_uniform_distribution *) malloc(sizeof(struct random_uniform_distribution));
+    rnd->gen = std::mt19937(seed);
+    rnd->rd = std::uniform_real_distribution<float>{min, max};
+    return rnd;
+}
+
+void free_random_normal_distribution (struct random_normal_distribution  * rnd) {
+    free(rnd);
+}
+
+void free_random_uniform_distribution(struct random_uniform_distribution * rnd) {
+    free(rnd);
+}
+
+struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
+    float scale = 1.0f; // xavier
+    switch (tensor->n_dims) {
+        case 1:
+            scale /= sqrtf((float) tensor->ne[0]);
+            for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
+                *dst = scale * frand_normal(rnd);
+            }
+            break;
+        case 2:
+            scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
+            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                    float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
+                    *dst = scale * frand_normal(rnd);
+                }
+            }
+            break;
+        case 3:
+            scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
+            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                        float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
+                        *dst = scale * frand_normal(rnd);
+                    }
+                }
+            }
+            break;
+        case 4:
+            scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
+            for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
+                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                            float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
+                            *dst = scale * frand_normal(rnd);
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            die("Unsupported tensor->n_dims");
+    };
+    return tensor;
+}
+
+struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
+    switch (tensor->n_dims) {
+        case 1:
+            for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
+                *dst = frand_uniform(rnd);
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                    float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
+                    *dst = frand_uniform(rnd);
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                        float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
+                        *dst = frand_uniform(rnd);
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
+                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                            float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
+                            *dst = frand_uniform(rnd);
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            die("Unsupported tensor->n_dims");
+    };
+    return tensor;
+}
+
+float frand() {
+    return (float)rand()/((float)(RAND_MAX) + 1.0f);
+}
+
+float frand_normal(struct random_normal_distribution * rnd) {
+    return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
+}
+
+float frand_uniform(struct random_uniform_distribution * rnd) {
+    return rnd->rd(rnd->gen);
+}
+
+int clamp(const int v, const int min, const int max) {
+    return ((v < min) ? (min) : (v > max) ? (max) : v);
+}
+
+float fclamp(const float v, const float min, const float max) {
+    return ((v < min) ? (min) : (v > max) ? (max) : v);
+}
+
+void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
+    GGML_ASSERT(tensor->n_dims == 1);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+}
+
+void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
+    GGML_ASSERT(tensor->n_dims == 2);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+    GGML_ASSERT(tensor->ne[1] == ne1);
+}
+
+void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
+    GGML_ASSERT(tensor->n_dims == 3);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+    GGML_ASSERT(tensor->ne[1] == ne1);
+    GGML_ASSERT(tensor->ne[2] == ne2);
+}
+
+void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
+    GGML_ASSERT(tensor->n_dims == 4);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+    GGML_ASSERT(tensor->ne[1] == ne1);
+    GGML_ASSERT(tensor->ne[2] == ne2);
+    GGML_ASSERT(tensor->ne[3] == ne3);
+}
+
+int64_t get_example_targets_batch(
+    struct llama_context * lctx,
+    struct ggml_tensor   * tokens_input,
+    struct ggml_tensor   * target_probs,
+    int64_t                example_id,
+    const size_t         * samples_offs,
+    const size_t         * samples_begin,
+    const size_t         * samples_size,
+          size_t           samples_count,
+    const llama_token    * train_data,
+    size_t                 n_train_data,
+    bool                   separate_with_eos,
+    bool                   separate_with_bos,
+    bool                   fill_with_next_samples,
+    bool                   sample_random_offsets
+) {
+    GGML_ASSERT(samples_count > 0);
+    GGML_ASSERT(tokens_input->n_dims  == 2);
+    GGML_ASSERT(target_probs->n_dims  == 3);
+    int64_t n_vocab  = target_probs->ne[0];
+    int64_t n_tokens = tokens_input->ne[0];
+    int64_t n_batch  = tokens_input->ne[1];
+    GGML_ASSERT(n_vocab  == target_probs->ne[0]);
+    GGML_ASSERT(n_tokens == target_probs->ne[1]);
+    GGML_ASSERT(n_batch  == target_probs->ne[2]);
+
+    int64_t used_samples = 0;
+
+    ggml_set_f32(target_probs, 0.0f);
+    llama_token bos = llama_token_bos(lctx);
+    llama_token eos = llama_token_eos(lctx);
+    // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
+    for (int k=0; k<n_batch; ++k) {
+        // printf("%s: batch %d\n", __func__, k);
+        size_t sample_idx   = (example_id + used_samples) % samples_count;
+        size_t sample_offs  = sample_random_offsets ? samples_offs[sample_idx] : 0;
+        size_t sample_begin = samples_begin[sample_idx];
+        size_t sample_size  = samples_size[sample_idx];
+        ++used_samples;
+
+        // printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
+        GGML_ASSERT(sample_begin+sample_size-1 < n_train_data);
+
+        ggml_set_i32_nd(tokens_input, 0, k, 0, 0, bos);
+        bool sample_separation_eos = !separate_with_eos;
+        bool sample_separation_bos = !separate_with_bos;
+        for (int64_t i=0; i<n_tokens; ++i) {
+            llama_token token = eos;
+            if (sample_offs >= sample_size && fill_with_next_samples) {
+                if (!sample_separation_eos) {
+                    // insert eos token to separate samples
+                    sample_separation_eos = true;
+                } else if (!sample_separation_bos) {
+                    // insert bos token to separate samples
+                    sample_separation_bos = true;
+                    token = bos;
+                } else {
+                    // sample separation is done, continue with next sample
+                    sample_separation_eos = !separate_with_eos;
+                    sample_separation_bos = !separate_with_bos;
+                    sample_offs  = 0;
+                    sample_idx   = (example_id + used_samples) % samples_count;
+                    sample_begin = samples_begin[sample_idx];
+                    sample_size  = samples_size[sample_idx];
+                    ++used_samples;
+                }
+            }
+            // note: no else-if here
+            if (sample_offs < sample_size) {
+                token = clamp(train_data[sample_begin+sample_offs], 0, (llama_token) (n_vocab - 1));
+                ++sample_offs;
+            }
+            ggml_set_f32_nd(target_probs,  token, (int) i, (int) k, 0, +1.0f);
+            if (i+1<n_tokens) {
+                ggml_set_i32_nd(tokens_input, (int) (i + 1), (int) k, 0, 0, token);
+            }
+        }
+    }
+
+    return used_samples;
+}
+
+void mt19937_set_state(std::mt19937& rng, const std::string& rng_state) {
+    std::stringstream s_rng_state;
+    s_rng_state.imbue(std::locale::classic());
+    s_rng_state.exceptions(std::stringstream::failbit);
+    s_rng_state.str(rng_state);
+    s_rng_state >> rng;
+}
+
+std::string mt19937_get_state(const std::mt19937& rng) {
+    std::stringstream s_rng_state;
+    s_rng_state.imbue(std::locale::classic());
+    s_rng_state << rng;
+    return s_rng_state.str();
+}
+
+std::string mt19937_seed_to_state(unsigned seed) {
+    std::mt19937 rng(seed);
+    return mt19937_get_state(rng);
+}
+
+std::string shuffle_samples(
+        const std::string & rng_state,
+        size_t            * shuffled_offs,
+        size_t            * shuffled_begins,
+        size_t            * shuffled_sizes,
+        const size_t      * begins,
+        const size_t      * sizes,
+        size_t              count) {
+    if (count == 0) return rng_state;
+
+    std::mt19937 rng;
+    mt19937_set_state(rng, rng_state);
+
+    // sort indices by random value for each index
+    std::vector<size_t> idcs;
+    {
+        std::vector<unsigned> rnd;
+        idcs.resize(count);
+        rnd.resize(count);
+        for (unsigned i=0; i<count; ++i) {
+            idcs[i] = i;
+            rnd[i]  = rng();
+        }
+
+        std::sort(idcs.begin(), idcs.end(), [&rnd](size_t a, size_t b){
+            // stable sort for reproducibility
+            return (rnd[a] == rnd[b]) ? (a < b) : (rnd[a] < rnd[b]);
+        });
+    }
+
+    // create random offsets
+    for (unsigned i=0; i<count; ++i) {
+        shuffled_offs[i] = (size_t) ((sizes[idcs[i]] - 1) * ((double) rng() / (double) (rng.max()-1)));
+    }
+
+    // reorder begins and sizes by sorted indices
+    for (unsigned i=0; i<count; ++i) {
+        shuffled_begins[i] = begins[idcs[i]];
+    }
+
+    for (unsigned i=0; i<count; ++i) {
+        shuffled_sizes[i] = sizes[idcs[i]];
+    }
+
+    return mt19937_get_state(rng);
+}
+
+size_t hash_combine(size_t h1, size_t h2) {
+    return h1 ^ (h2 << 1);
+}
+
+size_t compute_samples_hash(const char* fn, const size_t* samples_begin, const size_t* samples_size, size_t sample_count) {
+    std::hash<std::string> h_string;
+    std::hash<unsigned long long> h_ull;
+    size_t h = h_string(std::string(fn));
+    h = hash_combine(h, h_ull((unsigned long long) sample_count));
+    for (size_t i=0; i< sample_count; ++i) {
+        h = hash_combine(h, h_ull((unsigned long long) samples_begin[i]));
+        h = hash_combine(h, h_ull((unsigned long long) samples_size[i]));
+    }
+    return h;
+}
+
+std::string replace_str(const char * s, const char * needle, const char * replacement) {
+    std::string str = s;
+    size_t pos = str.find(needle);
+    if (pos != std::string::npos) {
+        str.replace(pos, strlen(needle), replacement);
+    }
+    return str;
+}
+
+void print_duration(double fmillis) {
+    if (fmillis < 1000.0f) {
+        printf("%.1fms", (float) fmillis);
+        return;
+    }
+    const int64_t one_sec  = 1000;
+    const int64_t one_min  = one_sec  * 60;
+    const int64_t one_hour = one_min  * 60;
+    const int64_t one_day  = one_hour * 24;
+
+    int64_t millis  = (int64_t) fmillis;
+    int64_t days    = millis/one_day;
+    int64_t hours   = (millis - days*one_day)/one_hour;
+    int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min;
+    int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec;
+
+    // to print int64_t either cast to (long long int) or use macro PRId64 from <inttypes.h>
+    if (days > 0) {
+        printf("%lldd ", (long long int) days);
+    }
+    printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds);
+}
+
+float cosine_decay(int64_t step, int64_t decay_steps, float minimum) {
+    if (step > decay_steps) {
+        step = decay_steps;
+    }
+    const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
+    const float decay = (1 - minimum)*cosine_decay + minimum;
+    return decay;
+}
+
+float cosine_decay_restart(int64_t step, int64_t decay_steps, float minimum, float restart_step_mult) {
+    while (step > decay_steps) {
+        step -= decay_steps;
+        decay_steps = (int64_t) (restart_step_mult * decay_steps);
+    }
+    return cosine_decay(step, decay_steps, minimum);
+}
+
+float learning_schedule(
+    int64_t step,
+    int64_t warmup_steps,
+    int64_t cos_decay_steps,
+    float   learning_rate,
+    float   overall_minimum,
+    float   cos_decay_minimum,
+    float   cos_decay_restart_step_mult,
+    bool    enable_restart) {
+
+    float result =
+        (step < warmup_steps)
+            ? (float) step / (float) warmup_steps
+            : enable_restart
+                ? cosine_decay_restart(
+                    step - warmup_steps,
+                    cos_decay_steps,
+                    cos_decay_minimum,
+                    cos_decay_restart_step_mult)
+                : cosine_decay(
+                    step,
+                    cos_decay_steps,
+                    cos_decay_minimum);
+
+    float min = overall_minimum / learning_rate;
+    result = min + result * (1.0f - min);
+    return result;
+}
+
+static bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
+    GGML_ASSERT(a != NULL);
+    GGML_ASSERT(b != NULL);
+    GGML_ASSERT(a->type == b->type);
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
+
+    return true;
+}
+
+void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
+    if (dst == NULL) {
+        return;
+    }
+    struct ggml_tensor * t  = ggml_get_tensor(ctx, name);
+    GGML_ASSERT(are_same_layout(dst, t));
+    memcpy(dst->data, t->data, ggml_nbytes(t));
+
+    if (strlen(ggml_get_name(dst)) == 0) {
+        ggml_set_name(dst, name);
+    }
+}
+
+// gguf constants
+static const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
+static const char * LLM_KV_OPTIMIZER_TYPE_ADAM  = "adam";
+static const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";
+static const char * LLM_KV_OPTIMIZER_FILE_VERSION               = "optimizer.file_version";
+static const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT     = "optimizer.convergence_past_count";
+static const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT            = "optimizer.parameter_count";
+static const char * LLM_KV_OPTIMIZER_ITERATION_COUNT            = "optimizer.iteration_count";
+static const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED           = "optimizer.just_initialized";
+static const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS             = "optimizer.adam.best_loss";
+static const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS         = "optimizer.adam.previous_loss";
+static const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT  = "optimizer.adam.no_improvement_count";
+static const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
+static const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS            = "optimizer.lbfgs.best_loss";
+static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP     = "optimizer.lbfgs.line_search_step";
+static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J        = "optimizer.lbfgs.line_search_j";
+static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K        = "optimizer.lbfgs.line_search_k";
+static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END      = "optimizer.lbfgs.line_search_end";
+static const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";
+
+static const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS    = "optimizer.adam.first_moments";
+static const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS   = "optimizer.adam.second_moments";
+static const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values";
+
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS  = "optimizer.lbfgs.current_parameters";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS   = "optimizer.lbfgs.current_gradients";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS  = "optimizer.lbfgs.previous_gradients";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION    = "optimizer.lbfgs.search_direction";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES    = "optimizer.lbfgs.past_loss_values";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA        = "optimizer.lbfgs.memory_alpha";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS           = "optimizer.lbfgs.memory_ys";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S            = "optimizer.lbfgs.memory_s";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y            = "optimizer.lbfgs.memory_y";
+
+static const char * LLM_KV_TRAINING_FILE_VERSION         = "training.file_version";
+static const char * LLM_KV_TRAINING_ITERATION_COUNT      = "training.iteration_count";
+static const char * LLM_KV_TRAINING_SAMPLE_COUNT         = "training.sample_count";
+static const char * LLM_KV_TRAINING_TOKEN_COUNT          = "training.token_count";
+static const char * LLM_KV_TRAINING_EPOCH_COUNT          = "training.epoch_count";
+static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
+static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE    = "training.shuffle.rng_state";
+static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
+static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE  = "training.shuffle.next_sample";
+
+#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
+{ \
+    const std::string skey(key); \
+    const int kid = gguf_find_key(ctx, skey.c_str()); \
+    if (kid >= 0) { \
+        enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
+        if (ktype != (type)) { \
+            die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
+        } \
+        (dst) = func(ctx, kid); \
+    } else if (req) { \
+        die_fmt("key not found in model: %s", skey.c_str()); \
+    } \
+}
+
+void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
+    // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
+
+    uint32_t file_version;
+    GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
+    GGML_ASSERT(file_version == 0);
+
+    GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
+    GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
+    GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
+
+    uint64_t nx;
+    GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
+    opt->nx = (size_t) nx;
+
+    // don't call ggml_opt_init until optimizer type and optimizer specific parameters are know
+
+    std::string opt_type;
+    GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
+    if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
+        opt->params.type = GGML_OPT_ADAM;
+
+        GGUF_GET_KEY(fctx, opt->adam.fx_best,          gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
+        GGUF_GET_KEY(fctx, opt->adam.fx_prev,          gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
+        GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32,  true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
+
+        ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
+
+        copy_tensor_by_name(opt->adam.m,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
+        copy_tensor_by_name(opt->adam.v,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
+        copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
+    } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
+        opt->params.type = GGML_OPT_LBFGS;
+
+        GGUF_GET_KEY(fctx, opt->params.lbfgs.m,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
+        GGUF_GET_KEY(fctx, opt->lbfgs.fx_best,          gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
+        GGUF_GET_KEY(fctx, opt->lbfgs.step,             gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
+        GGUF_GET_KEY(fctx, opt->lbfgs.j,                gguf_get_val_i32, GGUF_TYPE_INT32,   true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
+        GGUF_GET_KEY(fctx, opt->lbfgs.k,                gguf_get_val_i32, GGUF_TYPE_INT32,   true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
+        GGUF_GET_KEY(fctx, opt->lbfgs.end,              gguf_get_val_i32, GGUF_TYPE_INT32,   true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
+        GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32,  true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
+
+        ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
+
+        copy_tensor_by_name(opt->lbfgs.x,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
+        copy_tensor_by_name(opt->lbfgs.xp,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
+        copy_tensor_by_name(opt->lbfgs.g,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
+        copy_tensor_by_name(opt->lbfgs.gp,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
+        copy_tensor_by_name(opt->lbfgs.d,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
+        copy_tensor_by_name(opt->lbfgs.pf,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
+        copy_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
+        copy_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
+        copy_tensor_by_name(opt->lbfgs.lms,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
+        copy_tensor_by_name(opt->lbfgs.lmy,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
+    } else {
+        die("unknown optimizer type\n");
+    }
+}
+
+void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
+    gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
+    gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
+    gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
+    gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
+    gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
+
+    switch (opt->params.type) {
+        case GGML_OPT_ADAM:
+            {
+                gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
+                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS,            opt->adam.fx_best);
+                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS,        opt->adam.fx_prev);
+                gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
+
+                ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
+                ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
+                if (opt->adam.pf) {
+                    ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
+                }
+
+                gguf_add_tensor(fctx, opt->adam.m);
+                gguf_add_tensor(fctx, opt->adam.v);
+                if (opt->adam.pf) {
+                    gguf_add_tensor(fctx, opt->adam.pf);
+                }
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
+                gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
+                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS,            opt->lbfgs.fx_best);
+                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP,     opt->lbfgs.step);
+                gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J,        opt->lbfgs.j);
+                gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K,        opt->lbfgs.k);
+                gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END,      opt->lbfgs.end);
+                gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
+
+                ggml_set_name(opt->lbfgs.x,    LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
+                ggml_set_name(opt->lbfgs.xp,   LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
+                ggml_set_name(opt->lbfgs.g,    LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
+                ggml_set_name(opt->lbfgs.gp,   LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
+                ggml_set_name(opt->lbfgs.d,    LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
+                if (opt->lbfgs.pf) {
+                    ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
+                }
+                ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
+                ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
+                ggml_set_name(opt->lbfgs.lms,  LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
+                ggml_set_name(opt->lbfgs.lmy,  LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
+
+                gguf_add_tensor(fctx, opt->lbfgs.x);
+                gguf_add_tensor(fctx, opt->lbfgs.xp);
+                gguf_add_tensor(fctx, opt->lbfgs.g);
+                gguf_add_tensor(fctx, opt->lbfgs.gp);
+                gguf_add_tensor(fctx, opt->lbfgs.d);
+                if (opt->lbfgs.pf) {
+                    gguf_add_tensor(fctx, opt->lbfgs.pf);
+                }
+                gguf_add_tensor(fctx, opt->lbfgs.lmal);
+                gguf_add_tensor(fctx, opt->lbfgs.lmys);
+                gguf_add_tensor(fctx, opt->lbfgs.lms);
+                gguf_add_tensor(fctx, opt->lbfgs.lmy);
+            } break;
+    }
+}
+
+bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) {
+    if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) < 0) {
+        return false;
+    }
+
+    uint32_t file_version;
+    GGUF_GET_KEY(fctx, file_version,         gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
+    GGML_ASSERT(file_version <= 1);
+
+    if (file_version == 0) {
+
+        GGUF_GET_KEY(fctx, train->train_its,     gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
+        GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
+        GGUF_GET_KEY(fctx, train->train_tokens,  gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
+
+    } else if (file_version == 1) {
+
+        GGUF_GET_KEY(fctx, train->train_its,     gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT);
+        GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT);
+        GGUF_GET_KEY(fctx, train->train_tokens,  gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT);
+        GGUF_GET_KEY(fctx, train->train_epochs,  gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT);
+
+        GGUF_GET_KEY(fctx, train->shuffle_samples_hash,      gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH);
+        GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE);
+        GGUF_GET_KEY(fctx, train->shuffle_sample_count,      gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT);
+        GGUF_GET_KEY(fctx, train->shuffle_next_sample,       gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE);
+    }
+
+    load_opt_context_gguf(fctx, f_ggml_ctx, train->opt);
+    return true;
+}
+
+void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION,    1);
+    gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
+    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT,    train->train_samples);
+    gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT,     train->train_tokens);
+    gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT,     train->train_epochs);
+
+    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash);
+    gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE,    train->shuffle_rng_state_current.c_str());
+    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count);
+    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE,  (uint64_t) train->shuffle_next_sample);
+
+    save_opt_context_gguf(fctx, train->opt);
+}
+
+
+struct llama_file {
+    // use FILE * so we don't have to re-open the file to mmap
+    FILE * fp;
+    size_t size;
+
+    llama_file(const char * fname, const char * mode) {
+        fp = std::fopen(fname, mode);
+        if (fp == NULL) {
+            size = 0;
+        } else {
+            seek(0, SEEK_END);
+            size = tell();
+            seek(0, SEEK_SET);
+        }
+    }
+
+    size_t tell() const {
+#ifdef _WIN32
+        __int64 ret = _ftelli64(fp);
+#else
+        long ret = std::ftell(fp);
+#endif
+        GGML_ASSERT(ret != -1); // this really shouldn't fail
+        return (size_t) ret;
+    }
+
+    void seek(size_t offset, int whence) {
+#ifdef _WIN32
+        int ret = _fseeki64(fp, (__int64) offset, whence);
+#else
+        int ret = std::fseek(fp, (long) offset, whence);
+#endif
+        GGML_ASSERT(ret == 0); // same
+    }
+
+    void read_raw(void * ptr, size_t size) {
+        if (size == 0) {
+            return;
+        }
+        errno = 0;
+        std::size_t ret = std::fread(ptr, size, 1, fp);
+        if (ferror(fp)) {
+            die_fmt("read error: %s", strerror(errno));
+        }
+        if (ret != 1) {
+            die("unexpectedly reached end of file");
+        }
+    }
+
+    std::uint32_t read_u32() {
+        std::uint32_t ret;
+        read_raw(&ret, sizeof(ret));
+        return ret;
+    }
+
+    std::string read_string(std::uint32_t len) {
+        std::vector<char> chars(len);
+        read_raw(chars.data(), len);
+        return std::string(chars.data(), len);
+    }
+
+    void write_raw(const void * ptr, size_t size) {
+        if (size == 0) {
+            return;
+        }
+        errno = 0;
+        size_t ret = std::fwrite(ptr, size, 1, fp);
+        if (ret != 1) {
+            die_fmt("write error: %s", strerror(errno));
+        }
+    }
+
+    void write_u32(std::uint32_t val) {
+        write_raw(&val, sizeof(val));
+    }
+
+    ~llama_file() {
+        if (fp) {
+            std::fclose(fp);
+        }
+    }
+};
+
+static size_t utf8_len(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+    return lookup[highbits];
+}
+
+// mark each byte with its utf8 unit number.
+// returns the number of utf8 characters.
+// e.g. when bytes == '\x61\xD0\xB0\x62',
+// then utf8_units will become [0,0,1,0]
+// utf8_nunits will become [1,2,2,1] and 3 is returned.
+// bytes where utf8_units is zero, are the begin of an utf8 character.
+static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) {
+    size_t offs = 0;
+    size_t count_utf8 = 0;
+    while(offs < count) {
+        int len = (int) utf8_len(bytes[offs]);
+        for (int i=0; i<len; ++i) {
+            utf8_units[offs+i]  = i;
+            utf8_nunits[offs+i] = len;
+        }
+        offs += len;
+        ++count_utf8;
+    }
+    return count_utf8;
+}
+
+size_t tokenize_file(
+        struct llama_context     * lctx,
+        const char               * filename,
+        const std::string        & sample_start,
+        bool                       include_sample_start,
+        bool                       overlapping_samples,
+        unsigned                   context_length,
+        std::vector<llama_token> & out_tokens,
+        std::vector<size_t>      & out_samples_begin,
+        std::vector<size_t>      & out_samples_size) {
+    struct llama_file f(filename, "rb");
+
+    if (f.size == 0) {
+        out_tokens.clear();
+        out_samples_begin.clear();
+        out_samples_size.clear();
+        printf("%s: warning: empty or not existing training data file '%s'\n",
+            __func__, filename);
+        return out_tokens.size();
+    }
+
+    // account for possible leading whitespace that will be added by tokenizer
+    // e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
+    const int n_max_tokens_overhead = 1;
+
+    std::vector<char> buf;
+    buf.resize(f.size);
+
+    f.read_raw(buf.data(), f.size);
+
+    std::vector<int> utf8_units;
+    std::vector<int> utf8_nunits;
+    utf8_units.resize(buf.size());
+    utf8_nunits.resize(buf.size());
+    mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
+
+    if (sample_start.size() == 0) {
+        // tokenize all data at once
+        out_tokens.resize(buf.size() + n_max_tokens_overhead);
+
+        int n_tokens = llama_tokenize(
+            lctx,
+            buf.data(),
+            (int) buf.size(),
+            out_tokens.data(),
+            (int) out_tokens.size(),
+            false);
+        if (n_tokens < 0) {
+            out_tokens.resize(-n_tokens);
+            n_tokens = llama_tokenize(
+                lctx,
+                buf.data(),
+                (int) buf.size(),
+                out_tokens.data(),
+                (int) out_tokens.size(),
+                false);
+        }
+        if (n_tokens >= 0) {
+            out_tokens.resize(n_tokens);
+        }
+
+        // generate sample starts at all token positions
+        out_samples_begin.clear();
+        out_samples_begin.push_back(0);
+        out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size()));
+        size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0;
+        for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) {
+            out_samples_begin.push_back(sample_begin);
+            out_samples_size.push_back(context_length);
+        }
+    } else {
+        // split data into samples and tokenize each sample
+        std::string data_str(buf.data(), buf.size());
+        out_samples_begin.clear();
+        out_samples_size.clear();
+        out_tokens.clear();
+
+        // find all positions of pattern sample_start
+        size_t sample_begin = data_str.find(sample_start, 0);
+        while (sample_begin != std::string::npos) {
+            out_samples_begin.push_back(sample_begin);
+            const size_t search_start = sample_begin + sample_start.size();
+            sample_begin = data_str.find(sample_start, search_start);
+        }
+        if (out_samples_begin.size() == 0) {
+            printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n",
+                __func__, sample_start.c_str());
+            out_samples_begin.push_back(0);
+        }
+
+        out_samples_size.resize(out_samples_begin.size(), 0);
+
+        std::vector<char>        buf_sample;
+        std::vector<llama_token> tok_sample;
+
+        const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size());
+        size_t found_too_big_sample   = 0;
+        size_t found_too_small_sample = 0;
+        size_t found_empty_sample     = 0;
+        size_t found_min_sample_size  = SIZE_MAX;
+        size_t found_max_sample_size  = 0;
+
+        size_t max_token_text_size = 0;
+        int n_vocab = llama_n_vocab(lctx);
+        for (llama_token token=0; token < n_vocab; ++token) {
+            max_token_text_size = std::max(
+                max_token_text_size,
+                strlen(llama_token_get_text(lctx, token)));
+        }
+
+        // upper bound of context byte length.
+        // strings with this byte length should always tokenize to at least context_length tokens.
+        size_t context_byte_len = max_token_text_size*context_length;
+
+        for (unsigned i=0; i<out_samples_begin.size(); ++i) {
+            // determine sample begin and end from pattern positions
+            size_t sample_begin = out_samples_begin[i] + sample_begin_offset;
+            size_t sample_end   = overlapping_samples
+                                    ? std::min(
+                                        data_str.size(),
+                                        sample_begin + context_byte_len)
+                                    : (i+1 < out_samples_begin.size()
+                                        ? out_samples_begin[i+1]
+                                        : data_str.size());
+            if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
+                // sample end is in the middle of an utf8 character.
+                // advance sample_end to the begin of the next utf8 character.
+                sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
+            }
+            size_t sample_size = sample_end - sample_begin;
+            if (sample_size == 0) {
+                ++found_empty_sample;
+            }
+
+            if (sample_size > 0) {
+                // llama_tokenize expects zero terminated string,
+                // copy sample into buffer and zero terminate it.
+                buf_sample.resize(sample_size);
+                memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size);
+
+                // printf("sample: '%s'\n", buf_sample.data());
+
+                // tokenize the sample
+                tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
+                int n_tokens = llama_tokenize(lctx,
+                    buf_sample.data(),
+                    (int) buf_sample.size(),
+                    tok_sample.data(),
+                    (int) tok_sample.size(),
+                    false);
+                if (n_tokens < 0) {
+                    tok_sample.resize(-n_tokens);
+                    n_tokens = llama_tokenize(lctx,
+                        buf_sample.data(),
+                        (int) buf_sample.size(),
+                        tok_sample.data(),
+                        (int) tok_sample.size(),
+                        false);
+                    GGML_ASSERT(n_tokens >= 0);
+                }
+                GGML_ASSERT(n_tokens <= (int) tok_sample.size());
+
+                if ((size_t) n_tokens > context_length) {
+                    ++found_too_big_sample;
+                } else if ((size_t) n_tokens < context_length) {
+                    ++found_too_small_sample;
+                }
+                found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens);
+                found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens);
+
+                // write out tokens, start and size of sample
+                // overwrite the string start position with the token start position
+                out_samples_begin[i] = out_tokens.size();
+                out_samples_size[i] = (size_t) n_tokens;
+                out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens);
+            } else {
+                out_samples_begin[i] = out_tokens.size();
+                out_samples_size[i] = 0;
+            }
+
+        }
+        if (found_too_big_sample > 0) {
+            printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n",
+                __func__, found_too_big_sample, found_max_sample_size, context_length);
+        }
+
+        if (found_too_small_sample > 0) {
+            printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n",
+                __func__, found_too_small_sample, found_min_sample_size, context_length);
+        }
+
+        if (found_empty_sample) {
+            printf("%s: warning: found %zu empty samples.\n",
+                __func__, found_empty_sample);
+        }
+    }
+    printf("%s: total number of samples: %zu\n",
+        __func__, out_samples_begin.size());
+
+    GGML_ASSERT(out_samples_begin.size() == out_samples_size.size());
+
+    return out_tokens.size();
+}
+
+std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) {
+    std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
+    return replace_str(filename, pattern_it, sit.c_str());
+}
+
+struct train_params_common get_default_train_params_common() {
+    struct train_params_common params;
+    params.fn_train_data     = "shakespeare.txt";
+    params.fn_checkpoint_in  = "checkpoint.gguf";
+    params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
+    params.pattern_fn_it     = "ITERATION";
+    params.fn_latest         = "LATEST";
+
+    params.print_usage = false;
+
+    params.save_every = 10;
+
+    params.seed       =   -1;
+
+    params.n_ctx      =  128;
+    params.n_threads  =    6;
+    params.n_batch    =    8;
+    params.n_gradient_accumulation = 1;
+    params.n_epochs   = -1;
+
+    params.custom_n_ctx = false;
+
+    params.use_flash              = true;
+    params.use_checkpointing      = true;
+
+    params.sample_start           = "";
+    params.include_sample_start   = false;
+    params.escape                 = false;
+    params.overlapping_samples    = false;
+    params.fill_with_next_samples = false;
+    params.separate_with_eos      = false;
+    params.separate_with_bos      = true;
+    params.sample_random_offsets  = false;
+    params.force_reshuffle        = false;
+
+    params.opt_past               = 0;
+    params.opt_delta              = 1e-5f;
+    params.opt_max_no_improvement = 0;
+
+    params.warmup            =  100;
+    params.cos_decay_steps   = 1000;
+    params.cos_decay_restart = 1.1f;
+    params.cos_decay_min     = 0.1f;
+    params.enable_restart    = false;
+
+    params.adam_n_iter         = 256;
+    params.adam_alpha          = 1e-3f;
+    params.adam_min_alpha      = 0;
+    params.adam_decay          = 1e-1f;
+    params.adam_decay_min_ndim = 2;
+    params.adam_beta1          = 0.9f;
+    params.adam_beta2          = 0.999f;
+    params.adam_gclip          = 1.0f;
+    params.adam_eps_f          = 0.0f;
+    return params;
+}
+
+void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) {
+    // fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    // fprintf(stderr, "\n");
+    // fprintf(stderr, "options:\n");
+    // fprintf(stderr, "  -h, --help                 show this help message and exit\n");
+    fprintf(stderr, "  --train-data FNAME         path from which to load training data (default '%s')\n", params->fn_train_data);
+    fprintf(stderr, "  --checkpoint-in FNAME      path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
+    fprintf(stderr, "  --checkpoint-out FNAME     path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
+    fprintf(stderr, "  --pattern-fn-it STR        pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
+    fprintf(stderr, "  --fn-latest STR            string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
+    fprintf(stderr, "  --save-every N             save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
+    fprintf(stderr, "  -s SEED, --seed SEED       RNG seed (default: -1, use random seed for -1)\n");
+    fprintf(stderr, "  -c N, --ctx N              Context size used during training (default %d)\n", params->n_ctx);
+    fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
+    fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
+    fprintf(stderr, "  --grad-acc N               Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
+    fprintf(stderr, "  --sample-start STR         Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
+    fprintf(stderr, "  --include-sample-start     Include the sample start in the samples. (default off)\n");
+    fprintf(stderr, "  --escape                   process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
+    fprintf(stderr, "  --overlapping-samples      Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
+    fprintf(stderr, "  --fill-with-next-samples   Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
+    fprintf(stderr, "  --separate-with-eos        When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
+    fprintf(stderr, "  --separate-with-bos        When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
+    fprintf(stderr, "  --no-separate-with-eos     When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
+    fprintf(stderr, "  --no-separate-with-bos     When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
+    fprintf(stderr, "  --sample-random-offsets    Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : "");
+    fprintf(stderr, "  --force-reshuffle          Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
+    fprintf(stderr, "  --no-flash                 Don't use flash attention \n");
+    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
+    fprintf(stderr, "  --no-checkpointing         Don't use gradient checkpointing\n");
+    fprintf(stderr, "  --use-checkpointing        Use gradient checkpointing (default)\n");
+    fprintf(stderr, "  --warmup N                 Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
+    fprintf(stderr, "  --cos-decay-steps N        Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
+    fprintf(stderr, "  --cos-decay-restart N      Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
+    fprintf(stderr, "  --cos-decay-min N          Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
+    fprintf(stderr, "  --enable-restart N         Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
+    fprintf(stderr, "  --disable-restart N        Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
+    fprintf(stderr, "  --opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
+    fprintf(stderr, "  --opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
+    fprintf(stderr, "  --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
+    fprintf(stderr, "  --epochs N                 Maximum number epochs to process. (default %d)\n", params->n_epochs);
+    fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
+    fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
+    fprintf(stderr, "  --adam-min-alpha N         Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
+    fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, "  --adam-decay-min-ndim N    Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
+    fprintf(stderr, "  --adam-beta1 N             AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
+    fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
+    fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
+    fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
+    fprintf(stderr, "\n");
+}
+
+bool consume_common_train_arg(
+    int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param
+) {
+    int& i = *idx;
+    std::string arg = argv[i];
+    const std::string arg_prefix = "--";
+    if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+        std::replace(arg.begin(), arg.end(), '_', '-');
+    }
+    if (arg == "--train-data") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_train_data = argv[i];
+    } else if (arg == "--checkpoint-in") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_checkpoint_in = argv[i];
+    } else if (arg == "--checkpoint-out") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_checkpoint_out = argv[i];
+    } else if (arg == "--pattern-fn-it") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->pattern_fn_it = argv[i];
+    } else if (arg == "--fn-latest") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->fn_latest = argv[i];
+    } else if (arg == "--save-every") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->save_every = std::stoi(argv[i]);
+    } else if (arg == "-s" || arg == "--seed") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->seed = std::stoi(argv[i]);
+    } else if (arg == "-c" || arg == "--ctx") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_ctx = std::stoi(argv[i]);
+        params->custom_n_ctx = true;
+    } else if (arg == "-t" || arg == "--threads") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_threads = std::stoi(argv[i]);
+    } else if (arg == "-b" || arg == "--batch") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_batch = std::stoi(argv[i]);
+    } else if (arg == "--grad-acc") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
+    } else if (arg == "--sample-start") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->sample_start = std::string(argv[i]);
+    } else if (arg == "--escape") {
+        params->escape = true;
+    } else if (arg == "--include-sample-start") {
+        params->include_sample_start = true;
+    } else if (arg == "--overlapping-samples") {
+        params->overlapping_samples = true;
+    } else if (arg == "--fill-with-next-samples") {
+        params->fill_with_next_samples = true;
+    } else if (arg == "--separate-with-eos") {
+        params->separate_with_eos = true;
+    } else if (arg == "--separate-with-bos") {
+        params->separate_with_bos = true;
+    } else if (arg == "--no-separate-with-eos") {
+        params->separate_with_eos = false;
+    } else if (arg == "--no-separate-with-bos") {
+        params->separate_with_bos = false;
+    } else if (arg == "--sample-random-offsets") {
+        params->sample_random_offsets = true;
+    } else if (arg == "--force-reshuffle") {
+        params->force_reshuffle = true;
+    } else if (arg == "--no-flash") {
+        params->use_flash = false;
+    } else if (arg == "--use-flash") {
+        params->use_flash = true;
+    } else if (arg == "--no-checkpointing") {
+        params->use_checkpointing = false;
+    } else if (arg == "--use-checkpointing") {
+        params->use_checkpointing = true;
+    } else if (arg == "--warmup") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->warmup = std::stoi(argv[i]);
+    } else if (arg == "--cos-decay-steps") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->cos_decay_steps = std::stoi(argv[i]);
+    } else if (arg == "--cos-decay-restart") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->cos_decay_restart = std::stof(argv[i]);
+    } else if (arg == "--cos-decay-min") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->cos_decay_min = std::stof(argv[i]);
+    } else if (arg == "--enable-restart") {
+        params->enable_restart = true;
+    } else if (arg == "--disable-restart") {
+        params->enable_restart = false;
+    } else if (arg == "--opt-past") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->opt_past = std::stoi(argv[i]);
+    } else if (arg == "--opt-delta") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->opt_delta = std::stof(argv[i]);
+    } else if (arg == "--opt-max-no-improvement") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->opt_max_no_improvement = std::stoi(argv[i]);
+    } else if (arg == "--adam-epsf") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_eps_f = std::stof(argv[i]);
+    } else if (arg == "--epochs") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_epochs = std::stoi(argv[i]);
+    } else if (arg == "--adam-iter") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_n_iter = std::stoi(argv[i]);
+    } else if (arg == "--adam-alpha") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_alpha = std::stof(argv[i]);
+    } else if (arg == "--adam-min-alpha") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_min_alpha = std::stof(argv[i]);
+    } else if (arg == "--adam-decay") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_decay = std::stof(argv[i]);
+    } else if (arg == "--adam-decay-min-ndim") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_decay_min_ndim = std::stoi(argv[i]);
+    } else if (arg == "--adam-beta1") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_beta1 = std::stof(argv[i]);
+    } else if (arg == "--adam-beta2") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_beta2 = std::stof(argv[i]);
+    } else if (arg == "--adam-gclip") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->adam_gclip = std::stof(argv[i]);
+    } else if (arg == "-h" || arg == "--help") {
+        params->print_usage = true;
+        return true;
+    } else {
+        return false;
+    }
+    return true;
+}
+
+void finish_processing_train_args(struct train_params_common * params) {
+    if (params->escape) {
+        process_escapes(params->sample_start);
+    }
+}
+
+void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
+    struct train_opt_callback_data * data   = (struct train_opt_callback_data *) vdata;
+    struct train_params_common     * params = data->params;
+    struct train_state             * train  = data->train;
+    struct ggml_opt_context        * opt    = train->opt;
+    int n_batch = params->n_batch;
+    int n_ctx = params->n_ctx;
+
+    if (accum_step == 0) {
+        // time measurement
+        int64_t now = ggml_time_ms();
+        if (now > data->last_time && opt->iter > data->first_iter) {
+            double dt = (double) (now - data->last_time);
+            if (data->millis_per_iter == 0.0) {
+                data->millis_per_iter = dt;
+            } else {
+                const double gain = 0.7;
+                data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain;
+            }
+        }
+
+        double remaining_millis = 0.0;
+        if (data->millis_per_iter > 0.0) {
+            const int n_iter = params->adam_n_iter;
+            const int done_iter = opt->iter - data->first_iter;
+            const int remaining_iter = n_iter - done_iter;
+            remaining_millis = remaining_iter * data->millis_per_iter;
+        }
+
+        // file saving
+        const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
+        if (save_now) {
+            int new_iters = opt->iter - data->last_save_iter;
+            train->train_its    += new_iters;
+            train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
+
+            if (data->save_cb) {
+                data->save_cb(data->save_data, train);
+            }
+
+            data->last_save_iter = opt->iter;
+        }
+
+        // exclude file saving from time measurement, by measuring last_time after saving
+        data->last_time = ggml_time_ms();
+
+        *sched = learning_schedule(
+            opt->iter,
+            params->warmup,
+            params->cos_decay_steps,
+            params->adam_alpha,
+            params->adam_min_alpha,
+            params->cos_decay_min,
+            params->cos_decay_restart,
+            params->enable_restart);
+
+        int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
+        if (impr_plot > 0) impr_plot = 0;
+        if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
+        printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
+            __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
+            *sched, opt->loss_after);
+
+
+        if (data->millis_per_iter > 0) {
+            printf(" dt=");
+            print_duration(data->millis_per_iter);
+            printf(" eta=");
+            print_duration(remaining_millis);
+        }
+
+        float improvement = opt->loss_before - opt->loss_after;
+        const float plot_scale = 10.0f;
+        int bar_len = (int)(1 + improvement*plot_scale + 0.5);
+        printf(" |");
+        for (int i=0; i<bar_len; ++i) {
+            printf("-");
+        }
+        printf(">");
+        printf("\n");
+    }
+
+    int64_t used_samples = get_example_targets_batch(
+        data->lctx,
+        data->tokens_input,
+        data->target_probs,
+        train->shuffle_next_sample,
+        data->shuffled_samples_offs,
+        data->shuffled_samples_begin,
+        data->shuffled_samples_size,
+        data->samples_count,
+        data->tokens_data,
+        data->tokens_size,
+        params->separate_with_eos,
+        params->separate_with_bos,
+        params->fill_with_next_samples,
+        params->sample_random_offsets);
+
+    train->train_samples += used_samples;
+    train->shuffle_next_sample += used_samples;
+
+    if (train->shuffle_next_sample >= train->shuffle_sample_count) {
+        ++train->train_epochs;
+        printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
+        // note: we may have used some samples from the current shuffling more than once
+        train->shuffle_rng_state_current = train->shuffle_rng_state_next;
+        train->shuffle_rng_state_next = shuffle_samples(
+            train->shuffle_rng_state_current,
+            data->shuffled_samples_offs,
+            data->shuffled_samples_begin,
+            data->shuffled_samples_size,
+            data->samples_begin,
+            data->samples_size,
+            data->samples_count);
+        train->shuffle_next_sample = 0;
+    }
+
+    const bool last_epoch_reached = (params->n_epochs > 0 && (int64_t) train->train_epochs - data->first_epoch >= params->n_epochs);
+    if (last_epoch_reached) {
+        // allow optimization iteration at last epoch to be completed before canceling
+        if (data->iter_at_last_epoch < 0) {
+            data->iter_at_last_epoch = opt->iter;
+        } else if (opt->iter > data->iter_at_last_epoch) {
+            *cancel = true;
+        }
+    }
+}

+ 230 - 0
common/train.h

@@ -0,0 +1,230 @@
+// Various helper functions and utilities for training
+
+#pragma once
+
+#include <string>
+#include <random>
+#include <vector>
+
+#include "ggml.h"
+#include "llama.h"
+
+typedef std::string mt19937_state;
+
+struct train_state {
+    struct ggml_opt_context * opt;
+
+    uint64_t train_its;
+    uint64_t train_samples;
+    uint64_t train_tokens;
+    uint64_t train_epochs;
+
+    size_t        shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes)
+    mt19937_state shuffle_rng_state_current;
+    mt19937_state shuffle_rng_state_next;
+    size_t        shuffle_sample_count;
+    size_t        shuffle_next_sample;
+};
+
+struct train_params_common {
+    const char * fn_train_data;
+    const char * fn_checkpoint_in;
+    const char * fn_checkpoint_out;
+    const char * pattern_fn_it;
+    const char * fn_latest;
+
+    bool print_usage;
+
+    int save_every;
+
+    uint32_t seed;
+
+    int n_ctx;
+    int n_threads;
+    int n_batch;
+    int n_gradient_accumulation;
+    int n_epochs;
+
+    bool custom_n_ctx;
+
+    bool use_flash;
+    bool use_checkpointing;
+
+    std::string sample_start;
+    bool include_sample_start;
+    bool escape;
+    bool overlapping_samples;
+    bool fill_with_next_samples;
+    bool separate_with_eos;
+    bool separate_with_bos;
+    bool sample_random_offsets;
+
+    bool force_reshuffle;
+
+    int   warmup;
+    int   cos_decay_steps;
+    float cos_decay_restart;
+    float cos_decay_min;
+    bool  enable_restart;
+
+    int   opt_past;
+    float opt_delta;
+    int   opt_max_no_improvement;
+
+    int   adam_n_iter;
+    float adam_alpha;
+    float adam_min_alpha;
+    float adam_decay;
+    int   adam_decay_min_ndim;
+    float adam_beta1;
+    float adam_beta2;
+    float adam_gclip;
+    float adam_eps_f;
+};
+
+typedef void (*save_train_files_callback)(void * data, struct train_state * train);
+
+struct train_opt_callback_data {
+    struct train_params_common * params;
+    struct train_state         * train;
+    save_train_files_callback    save_cb;
+    void                       * save_data;
+    struct llama_context       * lctx;
+    int                          last_save_iter;
+    llama_token                * tokens_data;
+    size_t                       tokens_size;
+    size_t                     * samples_begin;
+    size_t                     * samples_size;
+    size_t                     * shuffled_samples_offs;
+    size_t                     * shuffled_samples_begin;
+    size_t                     * shuffled_samples_size;
+    size_t                       samples_count;
+    struct ggml_tensor         * tokens_input;
+    struct ggml_tensor         * target_probs;
+    int                          first_iter;
+    int                          first_epoch;
+    int                          iter_at_last_epoch;
+    int64_t                      last_time;
+    double                       millis_per_iter;
+};
+
+struct train_state * init_train_state();
+void free_train_state(struct train_state  * state);
+
+struct train_params_common get_default_train_params_common();
+void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params);
+
+bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param);
+void finish_processing_train_args(struct train_params_common * params);
+
+struct random_normal_distribution;
+struct random_uniform_distribution;
+
+struct random_normal_distribution  * init_random_normal_distribution (int seed, float mean, float std, float min, float max);
+struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max);
+
+void free_random_normal_distribution (struct random_normal_distribution  * rnd);
+void free_random_uniform_distribution(struct random_uniform_distribution * rnd);
+
+struct ggml_tensor * randomize_tensor_normal (struct ggml_tensor * tensor, struct random_normal_distribution * rnd);
+struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd);
+
+// generate random float in interval [0,1)
+float frand();
+float frand_normal (struct random_normal_distribution * rnd);
+float frand_uniform(struct random_uniform_distribution * rnd);
+
+int   clamp (const int v, const int min, const int max);
+float fclamp(const float v, const float min, const float max);
+
+void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0);
+void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1);
+void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2);
+void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3);
+
+size_t tokenize_file(
+        struct llama_context     * lctx,
+        const char               * filename,
+        const std::string        & sample_start,
+        bool                       include_sample_start,
+        bool                       overlapping_samples,
+        unsigned                   context_length,
+        std::vector<llama_token> & out_tokens,
+        std::vector<size_t>      & out_samples_begin,
+        std::vector<size_t>      & out_samples_size);
+
+int64_t get_example_targets_batch(
+        struct llama_context * lctx,
+        struct ggml_tensor   * tokens_input,
+        struct ggml_tensor   * target_probs,
+        int64_t                example_id,
+        const size_t         * samples_offs,
+        const size_t         * samples_begin,
+        const size_t         * samples_size,
+              size_t           samples_count,
+        const llama_token    * train_data,
+        size_t                 n_train_data,
+        bool                   separate_with_eos,
+        bool                   separate_with_bos,
+        bool                   fill_with_next_samples,
+        bool                   sample_random_offsets);
+
+
+void          mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state);
+mt19937_state mt19937_get_state(const std::mt19937& rng);
+mt19937_state mt19937_seed_to_state(unsigned seed);
+
+mt19937_state shuffle_samples(
+        const mt19937_state & rng_state,
+        size_t              * shuffled_offs,
+        size_t              * shuffled_begins,
+        size_t              * shuffled_sizes,
+        const size_t        * begins,
+        const size_t        * sizes,
+        size_t                count);
+
+size_t hash_combine(size_t h1, size_t h2);
+
+size_t compute_samples_hash(
+    const char* fn,
+    const size_t* samples_begin,
+    const size_t* samples_size,
+    size_t sample_count);
+
+
+std::string replace_str(const char * s, const char * needle, const char * replacement);
+
+void print_duration(double milliseconds);
+
+float cosine_decay(
+    int64_t step,
+    int64_t decay_steps,
+    float   minimum);
+
+float cosine_decay_restart(
+    int64_t step,
+    int64_t decay_steps,
+    float   minimum,
+    float   restart_step_mult);
+
+float learning_schedule(
+    int64_t step,
+    int64_t warmup_steps,
+    int64_t decay_steps,
+    float   learning_rate,
+    float   overall_minimum,
+    float   cos_decay_minimum,
+    float   cos_decay_restart_step_mult,
+    bool    enable_restart);
+
+void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name);
+
+void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt);
+void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt);
+
+bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train);
+void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train);
+
+std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);
+
+void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel);

+ 2 - 0
examples/CMakeLists.txt

@@ -21,6 +21,7 @@ else()
     add_subdirectory(benchmark)
     add_subdirectory(baby-llama)
     add_subdirectory(train-text-from-scratch)
+    add_subdirectory(finetune)
     add_subdirectory(convert-llama2c-to-ggml)
     add_subdirectory(simple)
     add_subdirectory(batched)
@@ -35,4 +36,5 @@ else()
     if (LLAMA_BUILD_SERVER)
         add_subdirectory(server)
     endif()
+    add_subdirectory(export-lora)
 endif()

+ 41 - 135
examples/baby-llama/baby-llama.cpp

@@ -1,4 +1,5 @@
 #include "ggml.h"
+#include "train.h"
 #include <vector>
 #include <cassert>
 #include <random>
@@ -14,31 +15,6 @@ constexpr float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS;
 constexpr float rms_norm_eps = 5e-6f;
 #endif
 
-static float frand() {
-    return (float)rand()/(float)RAND_MAX;
-}
-
-struct random_normal_distribution {
-    std::mt19937 gen;
-    std::normal_distribution<float> nd;
-    float min;
-    float max;
-};
-
-static void init_random_normal_distribution(
-    struct random_normal_distribution * rnd, int seed, float mean, float std, float min, float max
-) {
-    rnd->gen = std::mt19937(seed);
-    rnd->nd = std::normal_distribution<float>{mean, std};
-    rnd->min = min;
-    rnd->max = max;
-}
-
-static float frand_normal(struct random_normal_distribution * rnd) {
-    const float r = rnd->nd(rnd->gen);
-    return ((r < rnd->min) ? (rnd->min) : (r > rnd->max) ? (rnd->max) : r);
-}
-
 static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
     struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
 
@@ -93,54 +69,6 @@ static struct ggml_tensor * randomize_tensor(
     return tensor;
 }
 
-static struct ggml_tensor * randomize_tensor_normal(
-    struct ggml_tensor * tensor, int ndims, const int64_t ne[], struct random_normal_distribution * rnd
-) {
-    float scale = 1.0; // xavier
-    switch (ndims) {
-        case 1:
-            scale /= sqrtf(ne[0]);
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((float *)tensor->data)[i0] = scale * frand_normal(rnd);
-            }
-            break;
-        case 2:
-            scale /= sqrtf(ne[0]+ne[1]);
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((float *)tensor->data)[i1*ne[0] + i0] = scale * frand_normal(rnd);
-                }
-            }
-            break;
-        case 3:
-            scale /= sqrtf(ne[0]+ne[1]);
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd);
-                    }
-                }
-            }
-            break;
-        case 4:
-            scale /= sqrtf(ne[0]+ne[1]);
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd);
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    };
-
-    return tensor;
-}
-
 struct llama_hparams {
     uint32_t n_vocab = 32000;
     uint32_t n_ctx   = 512;   // this is provided as user input?
@@ -398,27 +326,29 @@ static void randomize_model(struct llama_model * model, int seed, float mean, fl
 
     const uint32_t n_layer = hparams.n_layer;
 
-    struct random_normal_distribution rnd;
-    init_random_normal_distribution(&rnd, seed, mean, std, min, max);
-    randomize_tensor_normal(model->tok_embeddings, model->tok_embeddings->n_dims, model->tok_embeddings->ne, &rnd);
-    randomize_tensor_normal(model->norm,           model->norm->n_dims,           model->norm->ne,           &rnd);
-    randomize_tensor_normal(model->output,         model->output->n_dims,         model->output->ne,         &rnd);
+    struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
+
+    randomize_tensor_normal(model->tok_embeddings , rnd);
+    randomize_tensor_normal(model->norm           , rnd);
+    randomize_tensor_normal(model->output         , rnd);
 
     for (uint32_t i = 0; i < n_layer; ++i) {
         auto & layer = model->layers[i];
-        randomize_tensor_normal(layer.attention_norm, layer.attention_norm->n_dims, layer.attention_norm->ne, &rnd);
+        randomize_tensor_normal(layer.attention_norm, rnd);
 
-        randomize_tensor_normal(layer.wq, layer.wq->n_dims, layer.wq->ne, &rnd);
-        randomize_tensor_normal(layer.wk, layer.wk->n_dims, layer.wk->ne, &rnd);
-        randomize_tensor_normal(layer.wv, layer.wv->n_dims, layer.wv->ne, &rnd);
-        randomize_tensor_normal(layer.wo, layer.wo->n_dims, layer.wo->ne, &rnd);
+        randomize_tensor_normal(layer.wq, rnd);
+        randomize_tensor_normal(layer.wk, rnd);
+        randomize_tensor_normal(layer.wv, rnd);
+        randomize_tensor_normal(layer.wo, rnd);
 
-        randomize_tensor_normal(layer.ffn_norm, layer.ffn_norm->n_dims, layer.ffn_norm->ne, &rnd);
+        randomize_tensor_normal(layer.ffn_norm, rnd);
 
-        randomize_tensor_normal(layer.w1, layer.w1->n_dims, layer.w1->ne, &rnd);
-        randomize_tensor_normal(layer.w2, layer.w2->n_dims, layer.w2->ne, &rnd);
-        randomize_tensor_normal(layer.w3, layer.w3->n_dims, layer.w3->ne, &rnd);
+        randomize_tensor_normal(layer.w1, rnd);
+        randomize_tensor_normal(layer.w2, rnd);
+        randomize_tensor_normal(layer.w3, rnd);
     }
+
+    free_random_normal_distribution(rnd);
 }
 
 
@@ -429,32 +359,34 @@ static void randomize_model_lora(
 
     const uint32_t n_layer = hparams.n_layer;
 
-    struct random_normal_distribution rnd;
-    init_random_normal_distribution(&rnd, seed, mean, std, min, max);
-    randomize_tensor_normal(model->tok_embeddings, model->tok_embeddings->n_dims, model->tok_embeddings->ne, &rnd);
-    randomize_tensor_normal(model->norm,           model->norm->n_dims,           model->norm->ne,           &rnd);
-    randomize_tensor_normal(model->outputa,        model->outputa->n_dims,        model->outputa->ne,         &rnd);
-    randomize_tensor_normal(model->outputb,        model->outputb->n_dims,        model->outputb->ne,         &rnd);
+    struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
+
+    randomize_tensor_normal(model->tok_embeddings, rnd);
+    randomize_tensor_normal(model->norm          , rnd);
+    randomize_tensor_normal(model->outputa       , rnd);
+    randomize_tensor_normal(model->outputb       , rnd);
 
     for (uint32_t i = 0; i < n_layer; ++i) {
         auto & layer = model->layers[i];
-        randomize_tensor_normal(layer.attention_norm, layer.attention_norm->n_dims, layer.attention_norm->ne, &rnd);
-
-        randomize_tensor_normal(layer.wqa, layer.wqa->n_dims, layer.wqa->ne, &rnd);
-        randomize_tensor_normal(layer.wqb, layer.wqb->n_dims, layer.wqb->ne, &rnd);
-        randomize_tensor_normal(layer.wka, layer.wka->n_dims, layer.wka->ne, &rnd);
-        randomize_tensor_normal(layer.wkb, layer.wkb->n_dims, layer.wkb->ne, &rnd);
-        randomize_tensor_normal(layer.wva, layer.wva->n_dims, layer.wva->ne, &rnd);
-        randomize_tensor_normal(layer.wvb, layer.wvb->n_dims, layer.wvb->ne, &rnd);
-        randomize_tensor_normal(layer.woa, layer.woa->n_dims, layer.woa->ne, &rnd);
-        randomize_tensor_normal(layer.wob, layer.wob->n_dims, layer.wob->ne, &rnd);
-
-        randomize_tensor_normal(layer.ffn_norm, layer.ffn_norm->n_dims, layer.ffn_norm->ne, &rnd);
-
-        randomize_tensor_normal(layer.w1, layer.w1->n_dims, layer.w1->ne, &rnd);
-        randomize_tensor_normal(layer.w2, layer.w2->n_dims, layer.w2->ne, &rnd);
-        randomize_tensor_normal(layer.w3, layer.w3->n_dims, layer.w3->ne, &rnd);
+        randomize_tensor_normal(layer.attention_norm, rnd);
+
+        randomize_tensor_normal(layer.wqa, rnd);
+        randomize_tensor_normal(layer.wqb, rnd);
+        randomize_tensor_normal(layer.wka, rnd);
+        randomize_tensor_normal(layer.wkb, rnd);
+        randomize_tensor_normal(layer.wva, rnd);
+        randomize_tensor_normal(layer.wvb, rnd);
+        randomize_tensor_normal(layer.woa, rnd);
+        randomize_tensor_normal(layer.wob, rnd);
+
+        randomize_tensor_normal(layer.ffn_norm, rnd);
+
+        randomize_tensor_normal(layer.w1, rnd);
+        randomize_tensor_normal(layer.w2, rnd);
+        randomize_tensor_normal(layer.w3, rnd);
     }
+
+    free_random_normal_distribution(rnd);
 }
 
 static bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model, int n_batch) {
@@ -762,32 +694,6 @@ static struct ggml_tensor * forward(
     return inpL;
 }
 
-static void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
-    GGML_ASSERT(tensor->n_dims == 1);
-    GGML_ASSERT(tensor->ne[0] == ne0);
-}
-
-static void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
-    GGML_ASSERT(tensor->n_dims == 2);
-    GGML_ASSERT(tensor->ne[0] == ne0);
-    GGML_ASSERT(tensor->ne[1] == ne1);
-}
-
-static void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
-    GGML_ASSERT(tensor->n_dims == 3);
-    GGML_ASSERT(tensor->ne[0] == ne0);
-    GGML_ASSERT(tensor->ne[1] == ne1);
-    GGML_ASSERT(tensor->ne[2] == ne2);
-}
-
-static void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
-    GGML_ASSERT(tensor->n_dims == 4);
-    GGML_ASSERT(tensor->ne[0] == ne0);
-    GGML_ASSERT(tensor->ne[1] == ne1);
-    GGML_ASSERT(tensor->ne[2] == ne2);
-    GGML_ASSERT(tensor->ne[3] == ne3);
-}
-
 static struct ggml_tensor * forward_batch(
     struct llama_model    * model,
     struct llama_kv_cache * cache,

+ 5 - 0
examples/export-lora/CMakeLists.txt

@@ -0,0 +1,5 @@
+set(TARGET export-lora)
+add_executable(${TARGET} export-lora.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

+ 26 - 0
examples/export-lora/README.md

@@ -0,0 +1,26 @@
+# export-lora
+
+Apply LORA adapters to base model and export the resulting model.
+
+```
+usage: export-lora [options]
+
+options:
+  -h, --help                         show this help message and exit
+  -m FNAME, --model-base FNAME       model path from which to load base model (default '')
+  -o FNAME, --model-out FNAME        path to save exported model (default '')
+  -l FNAME, --lora FNAME             apply LoRA adapter
+  -s FNAME S, --lora-scaled FNAME S  apply LoRA adapter with user defined scaling S
+  -t N, --threads N                  number of threads to use during computation (default: 4)
+```
+
+For example:
+
+```bash
+./bin/export-lora \
+    -m open-llama-3b-v2-q8_0.gguf \
+    -o open-llama-3b-v2-q8_0-english2tokipona-chat.gguf \
+    -l lora-open-llama-3b-v2-q8_0-english2tokipona-chat-LATEST.bin
+```
+
+Multiple LORA adapters can be applied by passing multiple `-l FN` or `-s FN S` command line parameters.

+ 474 - 0
examples/export-lora/export-lora.cpp

@@ -0,0 +1,474 @@
+
+#include "common.h"
+#include "ggml.h"
+#include "ggml-alloc.h"
+
+#include <vector>
+#include <string>
+#include <thread>
+
+static const size_t tensor_alignment = 32;
+
+struct lora_info {
+    std::string filename;
+    float scale;
+};
+
+struct export_lora_params {
+    std::string fn_model_base;
+    std::string fn_model_out;
+    std::vector<struct lora_info> lora;
+    int n_threads;
+};
+
+struct lora_data {
+    struct lora_info     info;
+    std::vector<uint8_t> data;
+    struct ggml_context * ctx;
+
+    uint32_t lora_r;
+    uint32_t lora_alpha;
+};
+
+struct llama_file {
+    // use FILE * so we don't have to re-open the file to mmap
+    FILE * fp;
+    size_t size;
+
+    llama_file(const char * fname, const char * mode) {
+        fp = std::fopen(fname, mode);
+        if (fp == NULL) {
+            size = 0;
+        } else {
+            seek(0, SEEK_END);
+            size = tell();
+            seek(0, SEEK_SET);
+        }
+    }
+
+    size_t tell() const {
+#ifdef _WIN32
+        __int64 ret = _ftelli64(fp);
+#else
+        long ret = std::ftell(fp);
+#endif
+        GGML_ASSERT(ret != -1); // this really shouldn't fail
+        return (size_t) ret;
+    }
+
+    void seek(size_t offset, int whence) {
+#ifdef _WIN32
+        int ret = _fseeki64(fp, (__int64) offset, whence);
+#else
+        int ret = std::fseek(fp, (long) offset, whence);
+#endif
+        GGML_ASSERT(ret == 0); // same
+    }
+
+    void read_raw(void * ptr, size_t size) {
+        if (size == 0) {
+            return;
+        }
+        errno = 0;
+        std::size_t ret = std::fread(ptr, size, 1, fp);
+        if (ferror(fp)) {
+            die_fmt("read error: %s", strerror(errno));
+        }
+        if (ret != 1) {
+            die("unexpectedly reached end of file");
+        }
+    }
+
+    std::uint32_t read_u32() {
+        std::uint32_t ret;
+        read_raw(&ret, sizeof(ret));
+        return ret;
+    }
+
+    std::string read_string(std::uint32_t len) {
+        std::vector<char> chars(len);
+        read_raw(chars.data(), len);
+        return std::string(chars.data(), len);
+    }
+
+    void write_raw(const void * ptr, size_t size) {
+        if (size == 0) {
+            return;
+        }
+        errno = 0;
+        size_t ret = std::fwrite(ptr, size, 1, fp);
+        if (ret != 1) {
+            die_fmt("write error: %s", strerror(errno));
+        }
+    }
+
+    void write_u32(std::uint32_t val) {
+        write_raw(&val, sizeof(val));
+    }
+
+    bool eof() {
+        return tell() >= size;
+    }
+
+    ~llama_file() {
+        if (fp) {
+            std::fclose(fp);
+        }
+    }
+};
+
+static struct export_lora_params get_default_export_lora_params() {
+    struct export_lora_params result;
+    result.fn_model_base = "";
+    result.fn_model_out  = "";
+    result.n_threads = GGML_DEFAULT_N_THREADS;
+    return result;
+}
+
+static void export_lora_print_usage(int /*argc*/, char ** argv, const struct export_lora_params * params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help                         show this help message and exit\n");
+    fprintf(stderr, "  -m FNAME, --model-base FNAME       model path from which to load base model (default '%s')\n", params->fn_model_base.c_str());
+    fprintf(stderr, "  -o FNAME, --model-out FNAME        path to save exported model (default '%s')\n", params->fn_model_out.c_str());
+    fprintf(stderr, "  -l FNAME, --lora FNAME             apply LoRA adapter\n");
+    fprintf(stderr, "  -s FNAME S, --lora-scaled FNAME S  apply LoRA adapter with user defined scaling S\n");
+    fprintf(stderr, "  -t N, --threads N                  number of threads to use during computation (default: %d)\n", params->n_threads);
+}
+
+static bool export_lora_params_parse(int argc, char ** argv, struct export_lora_params * params) {
+    bool invalid_param = false;
+    std::string arg;
+    struct export_lora_params default_params = get_default_export_lora_params();
+    const std::string arg_prefix = "--";
+
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+            std::replace(arg.begin(), arg.end(), '_', '-');
+        }
+
+        if (arg == "-m" || arg == "--model-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_model_base = argv[i];
+        } else if (arg == "-o" || arg == "--model-out") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_model_out = argv[i];
+        } else if (arg == "-l" || arg == "--lora") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            struct lora_info lora;
+            lora.filename = argv[i];
+            lora.scale = 1.0f;
+            params->lora.push_back(lora);
+        } else if (arg == "-s" || arg == "--lora-scaled") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            struct lora_info lora;
+            lora.filename = argv[i];
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            lora.scale = std::stof(argv[i]);
+            params->lora.push_back(lora);
+        } else if (arg == "-t" || arg == "--threads") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_threads = std::stoi(argv[i]);
+            if (params->n_threads <= 0) {
+                params->n_threads = std::thread::hardware_concurrency();
+            }
+        } else {
+            fprintf(stderr, "error: unknown argument: '%s'\n", arg.c_str());
+            export_lora_print_usage(argc, argv, &default_params);
+            exit(1);
+        }
+    }
+
+    if (params->fn_model_base == default_params.fn_model_base) {
+        fprintf(stderr, "error: please specify a filename for model-base.\n");
+        export_lora_print_usage(argc, argv, &default_params);
+        exit(1);
+    }
+    if (params->fn_model_out == default_params.fn_model_out) {
+        fprintf(stderr, "error: please specify a filename for model-out.\n");
+        export_lora_print_usage(argc, argv, &default_params);
+        exit(1);
+    }
+    if (invalid_param) {
+        fprintf(stderr, "error: invalid parameter for argument: '%s'\n", arg.c_str());
+        export_lora_print_usage(argc, argv, &default_params);
+        exit(1);
+    }
+    return true;
+}
+
+static void free_lora(struct lora_data * lora) {
+    if (lora->ctx != NULL) {
+        ggml_free(lora->ctx);
+    }
+    delete lora;
+}
+
+static struct lora_data * load_lora(struct lora_info * info) {
+    struct lora_data * result = new struct lora_data;
+    result->info = *info;
+    result->ctx = NULL;
+    result->lora_r     = 1;
+    result->lora_alpha = 1;
+
+    struct llama_file file(info->filename.c_str(), "rb");
+    if (file.fp == NULL) {
+        fprintf(stderr, "warning: Could not open lora adapter '%s'. Ignoring this adapter.\n",
+            info->filename.c_str());
+        free_lora(result);
+        return NULL;
+    }
+
+    struct ggml_init_params params_ggml;
+    params_ggml.mem_size   = ggml_tensor_overhead() * GGML_MAX_NODES;
+    params_ggml.mem_buffer = NULL;
+    params_ggml.no_alloc   = true;
+    result->ctx = ggml_init(params_ggml);
+
+    uint32_t LLAMA_FILE_MAGIC_LORA = 0x67676C61; // 'ggla'
+    uint32_t magic   = file.read_u32();
+    if (magic != LLAMA_FILE_MAGIC_LORA) {
+        die_fmt("unexpected lora header file magic in '%s'", info->filename.c_str());
+    }
+    uint32_t version = file.read_u32();
+    if (version != 1) {
+        die_fmt("unexpected lora file version '%u' in '%s'", (unsigned) version, info->filename.c_str());
+    }
+    result->lora_r     = file.read_u32();
+    result->lora_alpha = file.read_u32();
+    // read tensor infos from file
+    std::vector<char> name_buf;
+    std::vector<struct ggml_tensor *> tensors;
+    std::vector<size_t> tensors_offset;
+    size_t total_nbytes_pad = 0;
+    while(!file.eof()) {
+        int64_t ne[4]   = {1,1,1,1};
+        uint32_t n_dims  = file.read_u32();
+        uint32_t namelen = file.read_u32();
+        uint32_t type    = file.read_u32();
+        for (uint32_t k = 0; k < n_dims; ++k) {
+            ne[k] = (int64_t)file.read_u32();
+        }
+        name_buf.clear();
+        name_buf.resize(namelen + 1, '\0');
+        file.read_raw(name_buf.data(), namelen);
+        file.seek((0-file.tell()) & 31, SEEK_CUR);
+        size_t offset = file.tell();
+        struct ggml_tensor * tensor = ggml_new_tensor(result->ctx, (enum ggml_type) type, n_dims, ne);
+        ggml_set_name(tensor, name_buf.data());
+        size_t nbytes     = ggml_nbytes(tensor);
+        size_t nbytes_pad = ggml_nbytes_pad(tensor);
+        total_nbytes_pad += nbytes_pad;
+        tensors.push_back(tensor);
+        tensors_offset.push_back(offset);
+        file.seek(nbytes, SEEK_CUR);
+    }
+    // read tensor data
+    result->data.resize(total_nbytes_pad);
+    size_t data_offset = 0;
+    for (size_t i = 0; i < tensors.size(); ++i) {
+        struct ggml_tensor * tensor = tensors[i];
+        size_t offset     = tensors_offset[i];
+        size_t nbytes     = ggml_nbytes(tensor);
+        size_t nbytes_pad = ggml_nbytes_pad(tensor);
+        file.seek(offset, SEEK_SET);
+        tensor->data = result->data.data() + data_offset;
+        file.read_raw(tensor->data, nbytes);
+        data_offset += nbytes_pad;
+    }
+    return result;
+}
+
+
+static struct ggml_cgraph * build_graph_lora(
+    struct ggml_context * ctx,
+    struct ggml_tensor * tensor,
+    struct ggml_tensor * lora_a,
+    struct ggml_tensor * lora_b,
+    float scaling
+) {
+    struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b);
+    if (scaling != 1.0f) {
+        ab = ggml_scale(ctx, ab, ggml_new_f32(ctx, scaling));
+    }
+    struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab);
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx);
+    ggml_build_forward_expand (gf, res);
+    return gf;
+}
+
+static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int n_threads) {
+    if (lora->ctx == NULL) {
+        return false;
+    }
+    std::string name = ggml_get_name(tensor);
+    std::string name_a = name + std::string(".loraA");
+    std::string name_b = name + std::string(".loraB");
+    struct ggml_tensor * lora_a = ggml_get_tensor(lora->ctx, name_a.c_str());
+    struct ggml_tensor * lora_b = ggml_get_tensor(lora->ctx, name_b.c_str());
+    if (lora_a == NULL || lora_b == NULL) {
+        return false;
+    }
+
+    float scaling = lora->info.scale * (float)lora->lora_alpha / (float)lora->lora_r;
+
+    struct ggml_init_params params;
+    params.mem_size   = GGML_OBJECT_SIZE + GGML_GRAPH_SIZE + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
+    params.mem_buffer = NULL;
+    params.no_alloc   = true;
+    struct ggml_context * ctx = NULL;
+    struct ggml_allocr * alloc = NULL;
+    struct ggml_cgraph * gf = NULL;
+
+    ctx   = ggml_init(params);
+    alloc = ggml_allocr_new_measure(tensor_alignment);
+    gf    = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
+    size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf);
+    ggml_allocr_free(alloc);
+    ggml_free(ctx);
+
+    static std::vector<uint8_t> data_compute;
+    data_compute.resize(alloc_size + tensor_alignment);
+
+    ctx   = ggml_init(params);
+    alloc = ggml_allocr_new(data_compute.data(), data_compute.size(), tensor_alignment);
+    gf    = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
+    ggml_allocr_alloc_graph(alloc, gf);
+    ggml_allocr_free(alloc);
+
+    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
+    static std::vector<uint8_t> data_work;
+    data_work.resize(cplan.work_size);
+    cplan.work_data = data_work.data();
+
+    ggml_graph_compute(gf, &cplan);
+
+    ggml_free(ctx);
+    return true;
+}
+
+static void export_lora(struct export_lora_params * params) {
+    // load all loras
+    std::vector<struct lora_data *> loras;
+    for (size_t i = 0; i < params->lora.size(); ++i) {
+        struct lora_data * lora = load_lora(&params->lora[i]);
+        if (lora != NULL) {
+            loras.push_back(lora);
+        }
+    }
+    if (loras.size() == 0) {
+        fprintf(stderr, "warning: no lora adapters will be applied.\n");
+    }
+
+    // open input file
+    struct llama_file fin(params->fn_model_base.c_str(), "rb");
+    if (!fin.fp) {
+        die_fmt("Could not open file '%s'\n", params->fn_model_base.c_str());
+    }
+
+    // open base model gguf, read tensors without their data
+    struct ggml_context * ctx_in;
+    struct gguf_init_params params_gguf;
+    params_gguf.no_alloc = true;
+    params_gguf.ctx      = &ctx_in;
+    struct gguf_context * gguf_in = gguf_init_from_file(params->fn_model_base.c_str(), params_gguf);
+
+    // create new gguf
+    struct gguf_context * gguf_out = gguf_init_empty();
+
+    // copy meta data from base model: kv and tensors
+    gguf_set_kv(gguf_out, gguf_in);
+    int n_tensors = gguf_get_n_tensors(gguf_in);
+    for (int i=0; i < n_tensors; ++i) {
+        const char * name = gguf_get_tensor_name(gguf_in, i);
+        struct ggml_tensor * tensor = ggml_get_tensor(ctx_in, name);
+        gguf_add_tensor(gguf_out, tensor);
+    }
+
+    // create output file
+    struct llama_file fout(params->fn_model_out.c_str(), "wb");
+    if (!fout.fp) {
+        die_fmt("Could not create file '%s'\n", params->fn_model_out.c_str());
+    }
+
+    // write gguf meta data
+    std::vector<uint8_t> meta;
+    meta.resize(gguf_get_meta_size(gguf_out));
+    gguf_get_meta_data(gguf_out, meta.data());
+    fout.write_raw(meta.data(), meta.size());
+
+    std::vector<uint8_t> data;
+    std::vector<uint8_t> padding;
+    for (int i=0; i < n_tensors; ++i) {
+        const char * name = gguf_get_tensor_name(gguf_in, i);
+        struct ggml_tensor * tensor = ggml_get_tensor(ctx_in, name);
+
+        // read tensor data
+        data.resize(ggml_nbytes(tensor));
+        tensor->data = data.data();
+        size_t offset = gguf_get_tensor_offset(gguf_in, i);
+        fin.seek(offset + meta.size(), SEEK_SET);
+        fin.read_raw(data.data(), data.size());
+
+        // apply all loras
+        for (size_t k = 0; k < loras.size(); ++k) {
+            apply_lora(tensor, loras[k], params->n_threads);
+        }
+
+        // write tensor data + padding
+        padding.clear();
+        padding.resize(GGML_PAD(data.size(), gguf_get_alignment(gguf_out)) - data.size(), 0);
+
+        GGML_ASSERT(fout.tell() == offset + meta.size());
+        // fout.seek(offset + meta.size(), SEEK_SET);
+        fout.write_raw(data.data(), data.size());
+        fout.write_raw(padding.data(), padding.size());
+
+        if (i % 2 == 0) {
+            printf(".");
+        }
+    }
+    printf("\n");
+
+    // close gguf
+    gguf_free(gguf_out);
+    gguf_free(gguf_in);
+
+    // free loras
+    for (size_t i = 0; i < loras.size(); ++i) {
+        free_lora(loras[i]);
+    }
+}
+
+int main(int argc, char ** argv) {
+    struct export_lora_params params = get_default_export_lora_params();
+
+    if (!export_lora_params_parse(argc, argv, &params)) {
+        return 1;
+    }
+
+    export_lora(&params);
+
+    return 0;
+}

+ 5 - 0
examples/finetune/CMakeLists.txt

@@ -0,0 +1,5 @@
+set(TARGET finetune)
+add_executable(${TARGET} finetune.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

+ 90 - 0
examples/finetune/README.md

@@ -0,0 +1,90 @@
+# finetune
+
+Basic usage instructions:
+
+```bash
+# get training data
+wget https://raw.githubusercontent.com/brunoklein99/deep-learning-notes/master/shakespeare.txt
+
+# finetune LORA adapter
+./bin/finetune \
+        --model-base open-llama-3b-v2-q8_0.gguf \
+        --checkpoint-in  chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf \
+        --checkpoint-out chk-lora-open-llama-3b-v2-q8_0-shakespeare-ITERATION.gguf \
+        --lora-out lora-open-llama-3b-v2-q8_0-shakespeare-ITERATION.bin \
+        --train-data "shakespeare.txt" \
+        --save-every 10 \
+        --threads 6 --adam-iter 30 --batch 4 --ctx 64 \
+        --use-checkpointing
+
+# predict
+./bin/main -m open-llama-3b-v2-q8_0.gguf --lora lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin
+```
+
+Finetune output files will be saved every N iterations (config with `--save-every N`).
+The pattern 'ITERATION' in the output filenames will be replaced with the iteration number and with 'LATEST' for the latest output.
+So in above example after 10 iterations these files will be written:
+- chk-lora-open-llama-3b-v2-q8_0-shakespeare-10.gguf
+- chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf
+- lora-open-llama-3b-v2-q8_0-shakespeare-10.bin
+- lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin
+
+After 10 more iterations:
+- chk-lora-open-llama-3b-v2-q8_0-shakespeare-20.gguf
+- chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf
+- lora-open-llama-3b-v2-q8_0-shakespeare-20.bin
+- lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin
+
+Checkpoint files (`--checkpoint-in FN`, `--checkpoint-out FN`) store the training process. When the input checkpoint file does not exist, it will begin finetuning a new randomly initialized adapter.
+
+llama.cpp compatible LORA adapters will be saved with filename specified by `--lora-out FN`.
+These LORA adapters can then be used by `main` together with the base model, like in the 'predict' example command above.
+
+In `main` you can also load multiple LORA adapters, which will then be mixed together.
+
+For example if you have two LORA adapters `lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin` and `lora-open-llama-3b-v2-q8_0-bible-LATEST.bin`, you can mix them together like this:
+
+```bash
+./bin/main -m open-llama-3b-v2-q8_0.gguf \
+  --lora lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin \
+  --lora lora-open-llama-3b-v2-q8_0-bible-LATEST.bin
+```
+
+You can change how strong each LORA adapter is applied to the base model by using `--lora-scaled FN SCALE` instead of `--lora FN`.
+
+For example to apply 40% of the 'shakespeare' LORA adapter, 80% of the 'bible' LORA adapter and 100% of yet another one:
+
+```bash
+./bin/main -m open-llama-3b-v2-q8_0.gguf \
+  --lora-scaled lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin 0.4 \
+  --lora-scaled lora-open-llama-3b-v2-q8_0-bible-LATEST.bin 0.8 \
+  --lora lora-open-llama-3b-v2-q8_0-yet-another-one-LATEST.bin
+```
+
+The scale numbers don't need to add up to one, and you can also use numbers creater than 1 to further increase the influence of an adapter. But making the values to big will sometimes result in worse output. Play around to find good values.
+
+Gradient checkpointing reduces the memory requirements by ~50% but increases the runtime.
+If you have enough RAM, you can make finetuning a bit faster by disabling checkpointing with `--no-checkpointing`.
+
+The default LORA rank can be specified with `--lora-r N`.
+The LORA rank can be configured for each model tensor type separately with these command line options:
+
+```bash
+  --lora-r N                 LORA r: default rank. Also specifies resulting scaling together with lora-alpha. (default 4)
+  --rank-att-norm N          LORA rank for attention norm tensor (default 1)
+  --rank-ffn-norm N          LORA rank for feed-forward norm tensor (default 1)
+  --rank-out-norm N          LORA rank for output norm tensor (default 1)
+  --rank-tok-embd N          LORA rank for token embeddings tensor (default 4)
+  --rank-out N               LORA rank for output tensor (default 4)
+  --rank-wq N                LORA rank for wq tensor (default 4)
+  --rank-wk N                LORA rank for wk tensor (default 4)
+  --rank-wv N                LORA rank for wv tensor (default 4)
+  --rank-wo N                LORA rank for wo tensor (default 4)
+  --rank-w1 N                LORA rank for w1 tensor (default 4)
+  --rank-w2 N                LORA rank for w2 tensor (default 4)
+  --rank-w3 N                LORA rank for w3 tensor (default 4)
+```
+
+The LORA rank of 'norm' tensors should always be 1.
+
+To see all available options use `finetune --help`.

+ 489 - 0
examples/finetune/convert-finetune-checkpoint-to-gguf.py

@@ -0,0 +1,489 @@
+#!/usr/bin/env python3
+# finetune checkpoint --> gguf conversion
+
+import argparse
+import gguf
+import os
+import struct
+import sys
+import numpy as np
+from pathlib import Path
+
+# gguf constants
+LLM_KV_OPTIMIZER_TYPE = "optimizer.type"
+LLM_KV_OPTIMIZER_TYPE_ADAM  = "adam"
+LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs"
+LLM_KV_OPTIMIZER_FILE_VERSION               = "optimizer.file_version"
+LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT     = "optimizer.convergence_past_count"
+LLM_KV_OPTIMIZER_PARAMETER_COUNT            = "optimizer.parameter_count"
+LLM_KV_OPTIMIZER_ITERATION_COUNT            = "optimizer.iteration_count"
+LLM_KV_OPTIMIZER_JUST_INITIALIZED           = "optimizer.just_initialized"
+LLM_KV_OPTIMIZER_ADAM_BEST_LOSS             = "optimizer.adam.best_loss"
+LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS         = "optimizer.adam.previous_loss"
+LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT  = "optimizer.adam.no_improvement_count"
+LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count"
+LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS            = "optimizer.lbfgs.best_loss"
+LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP     = "optimizer.lbfgs.line_search_step"
+LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J        = "optimizer.lbfgs.line_search_j"
+LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K        = "optimizer.lbfgs.line_search_k"
+LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END      = "optimizer.lbfgs.line_search_end"
+LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count"
+
+LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS    = "optimizer.adam.first_moments"
+LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS   = "optimizer.adam.second_moments"
+LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values"
+
+LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS  = "optimizer.lbfgs.current_parameters"
+LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters"
+LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS   = "optimizer.lbfgs.current_gradients"
+LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS  = "optimizer.lbfgs.previous_gradients"
+LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION    = "optimizer.lbfgs.search_direction"
+LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES    = "optimizer.lbfgs.past_loss_values"
+LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA        = "optimizer.lbfgs.memory_alpha"
+LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS           = "optimizer.lbfgs.memory_ys"
+LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S            = "optimizer.lbfgs.memory_s"
+LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y            = "optimizer.lbfgs.memory_y"
+
+LLM_KV_TRAINING_TYPE_TRAIN_MODEL   = "train_model"
+LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"
+LLM_KV_TRAINING_TYPE               = "training.type"
+LLM_KV_TRAINING_FILE_VERSION       = "training.file_version"
+LLM_KV_TRAINING_ITERATION_COUNT    = "training.iteration_count"
+LLM_KV_TRAINING_SAMPLE_COUNT       = "training.sample_count"
+LLM_KV_TRAINING_TOKEN_COUNT        = "training.token_count"
+
+LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD  = "training.lora.rank.token_embd"
+LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm"
+LLM_KV_TRAINING_LORA_RANK_OUTPUT      = "training.lora.rank.output"
+LLM_KV_TRAINING_LORA_RANK_ATTN_NORM   = "training.lora.rank.attn_norm"
+LLM_KV_TRAINING_LORA_RANK_ATTN_Q      = "training.lora.rank.attn_q"
+LLM_KV_TRAINING_LORA_RANK_ATTN_K      = "training.lora.rank.attn_k"
+LLM_KV_TRAINING_LORA_RANK_ATTN_V      = "training.lora.rank.attn_v"
+LLM_KV_TRAINING_LORA_RANK_ATTN_OUT    = "training.lora.rank.attn_output"
+LLM_KV_TRAINING_LORA_RANK_FFN_NORM    = "training.lora.rank.ffn_norm"
+LLM_KV_TRAINING_LORA_RANK_FFN_GATE    = "training.lora.rank.ffn_gate"
+LLM_KV_TRAINING_LORA_RANK_FFN_DOWN    = "training.lora.rank.ffn_down"
+LLM_KV_TRAINING_LORA_RANK_FFN_UP      = "training.lora.rank.ffn_up"
+
+class Tensor:
+    def __init__(self, dtype='f', ne=None):
+        if ne is None:
+            ne = []
+        self.dtype = dtype
+        self.ne = ne
+        self.nbytes = 0
+        if self.dtype == 'f':
+            if len(self.ne) == 0:
+                self.nbytes = 0
+            else:
+                self.nbytes = int(np.product(self.ne)) * 4
+        else:
+            raise ValueError(f"Unhandled data type '{self.dtype}'")
+
+    def load(self, data, offset):
+        nd = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+        namelen = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+        dtype = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+
+        assert(nd == len(self.ne))
+        ne = []
+        for d in range(nd):
+            n = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+            ne.append(n)
+
+        if tuple(ne) != tuple(self.ne):
+            raise ValueError(f"Tensor.load: Expected number of elements {str(self.ne)} does not match what is read from file {str(ne)}")
+
+        if self.dtype == 'f':
+            assert(dtype == 0)
+        else:
+            raise ValueError(f"Unhandled data type '{self.dtype}'")
+
+        self.name = bytes(data[offset:offset+namelen]); offset += namelen
+        # 32-byte alignment
+        offset += (0 - offset) & 31
+        self.data = data[offset:offset+self.nbytes]
+        offset += self.nbytes
+        return offset
+
+    def max_storage_size(self):
+        result = 0
+        result += 4 # nd
+        result += 4 # namelen
+        result += 4 # dtype
+        result += len(self.ne)*8 # ne
+        result += 48 # name (maximum as of commit 3b5515bbe0e2224425986ba24f1f5d84aa38dce9)
+        result += 31 # 32-byte alignment
+        result += self.nbytes
+        return result
+
+    def save_gguf(self, gguf_writer, name):
+        gguf_writer.add_tensor(
+            name=name,
+            tensor=self.data,
+            raw_shape=np.array(list(reversed(self.ne))),
+            raw_dtype=gguf.GGMLQuantizationType.F32)
+
+class OptimizationContext:
+    def __init__(self):
+        pass
+
+    def load(self, data, offset):
+        self.version = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]
+        offset += 4
+
+        if self.version != 1:
+            raise ValueError('Invalid version of optimization context in checkpoint file')
+
+        self.past    = struct.unpack('<i', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.lbfgs_m = struct.unpack('<i', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.nx      = struct.unpack('N',  bytes(data[offset:offset + 8]))[0];  offset += 8
+        self.iter    = struct.unpack('<i', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.just_initialized = bool(struct.unpack('<i', bytes(data[offset:offset + 4]))[0]);  offset += 4
+
+        self.adam_m  = Tensor('f', [self.nx])
+        self.adam_v  = Tensor('f', [self.nx])
+        self.adam_pf = Tensor('f', [self.past] if self.past > 0 else [])
+
+        self.lbfgs_x    = Tensor('f', [self.nx])
+        self.lbfgs_xp   = Tensor('f', [self.nx])
+        self.lbfgs_g    = Tensor('f', [self.nx])
+        self.lbfgs_gp   = Tensor('f', [self.nx])
+        self.lbfgs_d    = Tensor('f', [self.nx])
+        self.lbfgs_pf   = Tensor('f', [self.past] if self.past > 0 else [])
+        self.lbfgs_lmal = Tensor('f', [self.lbfgs_m])
+        self.lbfgs_lmys = Tensor('f', [self.lbfgs_m])
+        self.lbfgs_lms  = Tensor('f', [self.nx, self.lbfgs_m])
+        self.lbfgs_lmy  = Tensor('f', [self.nx, self.lbfgs_m])
+
+        # forgot to save type in version 1:
+        # guess self.type from number of remaining bytes
+        size_type_0 = 12 + sum([t.max_storage_size() for t in
+                                [self.adam_m, self.adam_v]
+                                +([self.adam_pf] if (self.past > 0) else [])])
+        size_type_1 = 24 + sum([t.max_storage_size() for t in
+                                [self.lbfgs_x, self.lbfgs_xp, self.lbfgs_g,
+                                 self.lbfgs_gp, self.lbfgs_d, self.lbfgs_pf,
+                                 self.lbfgs_lmal, self.lbfgs_lmys,
+                                 self.lbfgs_lms, self.lbfgs_lmy]
+                                 +([self.lbfgs_pf] if (self.past > 0) else [])])
+        # due to alignment padding the size might not by exact
+        # but the difference in size for both types is significant,
+        # so we can just use whichever is closest
+        remaining = len(data) - offset
+        if abs(remaining - size_type_0) < abs(remaining - size_type_1):
+            self.type = 0
+        else:
+            self.type = 1
+
+        if self.type == 0:
+            offset = self.adam_m.load(data, offset)
+            offset = self.adam_v.load(data, offset)
+            offset = self.adam_pf.load(data,offset)
+
+            self.adam_fx_best          = struct.unpack('<f', bytes(data[offset:offset + 4]))[0];  offset += 4
+            self.adam_fx_prev          = struct.unpack('<f', bytes(data[offset:offset + 4]))[0];  offset += 4
+            self.adam_n_no_improvement = struct.unpack('<i', bytes(data[offset:offset + 4]))[0];  offset += 4
+
+        elif self.type == 1:
+            offset = self.lbfgs_x.load(data, offset)
+            offset = self.lbfgs_xp.load(data, offset)
+            offset = self.lbfgs_g.load(data, offset)
+            offset = self.lbfgs_gp.load(data, offset)
+            offset = self.lbfgs_d.load(data, offset)
+            offset = self.lbfgs_pf.load(data, offset)
+            offset = self.lbfgs_lmal.load(data, offset)
+            offset = self.lbfgs_lmys.load(data, offset)
+            offset = self.lbfgs_lms.load(data, offset)
+            offset = self.lbfgs_lmy.load(data, offset)
+
+            self.lbfgs_fx_best          = struct.unpack('<f', bytes(data[offset:offset + 4]))[0];  offset += 4
+            self.lbfgs_step             = struct.unpack('<f', bytes(data[offset:offset + 4]))[0];  offset += 4
+            self.lbfgs_j                = struct.unpack('<i', bytes(data[offset:offset + 4]))[0];  offset += 4
+            self.lbfgs_k                = struct.unpack('<i', bytes(data[offset:offset + 4]))[0];  offset += 4
+            self.lbfgs_end              = struct.unpack('<i', bytes(data[offset:offset + 4]))[0];  offset += 4
+            self.lbfgs_n_no_improvement = struct.unpack('<i', bytes(data[offset:offset + 4]))[0];  offset += 4
+
+        else:
+            raise ValueError(f"Invalid optimizer type '{self.type}'")
+
+        return offset
+
+    def save_gguf(self, gguf_writer):
+        gguf_writer.add_uint32(LLM_KV_OPTIMIZER_FILE_VERSION, 0)
+        gguf_writer.add_uint32(LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, self.past)
+        gguf_writer.add_uint64(LLM_KV_OPTIMIZER_PARAMETER_COUNT, self.nx)
+        gguf_writer.add_uint32(LLM_KV_OPTIMIZER_ITERATION_COUNT, self.iter)
+        gguf_writer.add_bool(LLM_KV_OPTIMIZER_JUST_INITIALIZED, self.just_initialized)
+
+        if self.type == 0:
+            gguf_writer.add_string(LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM)
+            gguf_writer.add_float32(LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, self.adam_fx_best)
+            gguf_writer.add_float32(LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, self.adam_fx_prev)
+            gguf_writer.add_uint32(LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, self.adam_n_no_improvement)
+
+            self.adam_m.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS)
+            self.adam_v.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS)
+            if self.past > 0:
+                self.adam_pf.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES)
+
+        elif self.type == 1:
+            gguf_writer.add_string(LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS)
+            gguf_writer.add_uint32(LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, self.lbfgs_m)
+            gguf_writer.add_float32(LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, self.lbfgs_fx_best)
+            gguf_writer.add_float32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, self.lbfgs_step)
+            gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, self.lbfgs_j)
+            gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, self.lbfgs_k)
+            gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, self.lbfgs_end)
+            gguf_writer.add_uint32(LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, self.lbfgs_n_no_improvement)
+
+            self.lbfgs_x.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS)
+            self.lbfgs_xp.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS)
+            self.lbfgs_g.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS)
+            self.lbfgs_gp.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS)
+            self.lbfgs_d.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION)
+            if self.past > 0:
+                self.lbfgs_pf.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES)
+            self.lbfgs_lmal.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA)
+            self.lbfgs_lmys.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS)
+            self.lbfgs_lms.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S)
+            self.lbfgs_lmy.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y)
+        else:
+            raise ValueError('Unknown optimizer type')
+
+class LoraParams:
+    def __init__(self):
+        pass
+
+    def load(self, data, offset):
+        self.n_rank_attention_norm  = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_wq              = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_wk              = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_wv              = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_wo              = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_ffn_norm        = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_w1              = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_w2              = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_w3              = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_tok_embeddings  = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_norm            = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rank_output          = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        return offset
+
+    def save_gguf(self, gguf_writer):
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD,  self.n_rank_tok_embeddings)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM, self.n_rank_norm)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_OUTPUT,      self.n_rank_output)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_ATTN_NORM,   self.n_rank_attention_norm)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_ATTN_Q,      self.n_rank_wq)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_ATTN_K,      self.n_rank_wk)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_ATTN_V,      self.n_rank_wv)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_ATTN_OUT,    self.n_rank_wo)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_FFN_NORM,    self.n_rank_ffn_norm)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_FFN_GATE,    self.n_rank_w1)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_FFN_DOWN,    self.n_rank_w2)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_FFN_UP,      self.n_rank_w3)
+
+class ModelParams:
+    def __init__(self, n_ff = None):
+        self.n_ff = n_ff
+
+    def load(self, data, offset):
+        self.n_vocab = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_embd  = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_mult  = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_head  = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_layer = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        self.n_rot   = struct.unpack('<I', bytes(data[offset:offset + 4]))[0];  offset += 4
+        return offset
+
+    def get_n_ff(self):
+        if self.n_ff is None:
+            # struct my_llama_model::get_n_ff in train-text-from-scratch.cpp commit 3b5515bbe0e2224425986ba24f1f5d84aa38dce9
+            return ((2*(4*self.n_embd)//3 + self.n_mult - 1)//self.n_mult)*self.n_mult
+        else:
+            return self.n_ff
+
+    def save_gguf(self, gguf_writer):
+        # self.n_vocab not saved
+        gguf_writer.add_embedding_length(self.n_embd)
+        gguf_writer.add_head_count(self.n_head)
+        gguf_writer.add_block_count(self.n_layer)
+        gguf_writer.add_rope_dimension_count(self.n_rot)
+        gguf_writer.add_feed_forward_length(self.get_n_ff())
+
+def tensor_name(key, bid=None, suffix=".weight"):
+    return gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][key].format(bid=bid) + suffix
+
+class Layer:
+    def __init__(self, params, lora_params, bid):
+        self.bid = bid
+        self.att_norm_a = Tensor('f', [lora_params.n_rank_attention_norm, params.n_embd])
+        self.att_norm_b = Tensor('f', [lora_params.n_rank_attention_norm, 1])
+        self.wq_a       = Tensor('f', [lora_params.n_rank_wq, params.n_embd])
+        self.wq_b       = Tensor('f', [lora_params.n_rank_wq, params.n_embd])
+        self.wk_a       = Tensor('f', [lora_params.n_rank_wk, params.n_embd])
+        self.wk_b       = Tensor('f', [lora_params.n_rank_wk, params.n_embd])
+        self.wv_a       = Tensor('f', [lora_params.n_rank_wv, params.n_embd])
+        self.wv_b       = Tensor('f', [lora_params.n_rank_wv, params.n_embd])
+        self.wo_a       = Tensor('f', [lora_params.n_rank_wo, params.n_embd])
+        self.wo_b       = Tensor('f', [lora_params.n_rank_wo, params.n_embd])
+        self.ffn_norm_a = Tensor('f', [lora_params.n_rank_ffn_norm, params.n_embd])
+        self.ffn_norm_b = Tensor('f', [lora_params.n_rank_ffn_norm, 1])
+        self.w1_a       = Tensor('f', [lora_params.n_rank_w1, params.n_embd])
+        self.w1_b       = Tensor('f', [lora_params.n_rank_w1, params.get_n_ff()])
+        self.w2_a       = Tensor('f', [lora_params.n_rank_w2, params.get_n_ff()])
+        self.w2_b       = Tensor('f', [lora_params.n_rank_w2, params.n_embd])
+        self.w3_a       = Tensor('f', [lora_params.n_rank_w3, params.n_embd])
+        self.w3_b       = Tensor('f', [lora_params.n_rank_w3, params.get_n_ff()])
+
+    def load(self, data, offset):
+        offset = self.att_norm_a.load(data, offset)
+        offset = self.att_norm_b.load(data, offset)
+        offset = self.wq_a.load(data, offset)
+        offset = self.wq_b.load(data, offset)
+        offset = self.wk_a.load(data, offset)
+        offset = self.wk_b.load(data, offset)
+        offset = self.wv_a.load(data, offset)
+        offset = self.wv_b.load(data, offset)
+        offset = self.wo_a.load(data, offset)
+        offset = self.wo_b.load(data, offset)
+        offset = self.ffn_norm_a.load(data, offset)
+        offset = self.ffn_norm_b.load(data, offset)
+        offset = self.w1_a.load(data, offset)
+        offset = self.w1_b.load(data, offset)
+        offset = self.w2_a.load(data, offset)
+        offset = self.w2_b.load(data, offset)
+        offset = self.w3_a.load(data, offset)
+        offset = self.w3_b.load(data, offset)
+        return offset
+
+    def save_gguf(self, gguf_writer):
+        self.att_norm_a.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_NORM, self.bid, ".weight.lora_a"))
+        self.att_norm_b.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_NORM, self.bid, ".weight.lora_b"))
+        self.wq_a.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_Q,    self.bid, ".weight.lora_a"))
+        self.wq_b.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_Q,    self.bid, ".weight.lora_b"))
+        self.wk_a.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_K,    self.bid, ".weight.lora_a"))
+        self.wk_b.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_K,    self.bid, ".weight.lora_b"))
+        self.wv_a.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_V,    self.bid, ".weight.lora_a"))
+        self.wv_b.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_V,    self.bid, ".weight.lora_b"))
+        self.wo_a.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_OUT,  self.bid, ".weight.lora_a"))
+        self.wo_b.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_OUT,  self.bid, ".weight.lora_b"))
+        self.ffn_norm_a.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_NORM,  self.bid, ".weight.lora_a"))
+        self.ffn_norm_b.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_NORM,  self.bid, ".weight.lora_b"))
+        self.w1_a.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_GATE,  self.bid, ".weight.lora_a"))
+        self.w1_b.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_GATE,  self.bid, ".weight.lora_b"))
+        self.w2_a.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_DOWN,  self.bid, ".weight.lora_a"))
+        self.w2_b.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_DOWN,  self.bid, ".weight.lora_b"))
+        self.w3_a.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_UP,    self.bid, ".weight.lora_a"))
+        self.w3_b.save_gguf      (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_UP,    self.bid, ".weight.lora_b"))
+
+class LoraModel:
+    def __init__(self, n_ff = None):
+        self.params = ModelParams(n_ff = n_ff)
+        self.lora_params = LoraParams()
+        self.layers = []
+
+    def load(self, data, offset):
+        offset = self.params.load(data, offset)
+        offset = self.lora_params.load(data, offset)
+
+        self.tok_embd_a = Tensor('f', [self.lora_params.n_rank_tok_embeddings, self.params.n_embd])
+        self.tok_embd_b = Tensor('f', [self.lora_params.n_rank_tok_embeddings, self.params.n_vocab])
+        self.norm_a     = Tensor('f', [self.lora_params.n_rank_norm, self.params.n_embd])
+        self.norm_b     = Tensor('f', [self.lora_params.n_rank_norm, 1])
+        self.output_a   = Tensor('f', [self.lora_params.n_rank_output, self.params.n_embd])
+        self.output_b   = Tensor('f', [self.lora_params.n_rank_output, self.params.n_vocab])
+
+        offset = self.tok_embd_a.load(data, offset)
+        offset = self.tok_embd_b.load(data, offset)
+        offset = self.norm_a.load(data, offset)
+        offset = self.norm_b.load(data, offset)
+        offset = self.output_a.load(data, offset)
+        offset = self.output_b.load(data, offset)
+
+        self.layers.clear()
+        for bid in range(self.params.n_layer):
+            layer = Layer(self.params, self.lora_params, bid)
+            offset = layer.load(data, offset)
+            self.layers.append(layer)
+
+        return offset
+
+    def save_gguf(self, gguf_writer):
+        self.params.save_gguf(gguf_writer)
+        self.lora_params.save_gguf(gguf_writer)
+
+        self.tok_embd_a.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD,  suffix=".weight.lora_a"))
+        self.tok_embd_b.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD,  suffix=".weight.lora_b"))
+        self.norm_a.save_gguf    (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.OUTPUT_NORM, suffix=".weight.lora_a"))
+        self.norm_b.save_gguf    (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.OUTPUT_NORM, suffix=".weight.lora_b"))
+        self.output_a.save_gguf  (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.OUTPUT,      suffix=".weight.lora_a"))
+        self.output_b.save_gguf  (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.OUTPUT,      suffix=".weight.lora_b"))
+
+        for layer in self.layers:
+            layer.save_gguf(gguf_writer)
+
+class LoraCheckpoint:
+    def __init__(self, n_ff = None):
+        self.model = LoraModel(n_ff = n_ff)
+        self.opt_ctx = OptimizationContext()
+
+    def load(self, data, offset):
+        magic   = bytes(reversed(data[offset:offset + 4])); offset += 4
+        if magic != b'ggcl':
+            raise ValueError(f"File header magic indicates, that this is no finetune-lora checkpoint file. Expected 'ggcl', Got '{str(magic)}'")
+
+        self.version = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+        if self.version != 0:
+            raise ValueError('Invalid version of checkpoint file')
+
+        self.train_its     = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+        self.train_samples = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+        self.train_tokens  = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+
+        offset = self.model.load(data, offset)
+        offset = self.opt_ctx.load(data, offset)
+
+        return offset
+
+    def save_gguf(self, gguf_writer):
+        gguf_writer.add_file_type(gguf.GGMLQuantizationType.F32)
+        gguf_writer.add_layer_norm_rms_eps(1e-5)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_FILE_VERSION,    0)
+        gguf_writer.add_string(LLM_KV_TRAINING_TYPE,            LLM_KV_TRAINING_TYPE_FINETUNE_LORA)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_ITERATION_COUNT, self.train_its)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_SAMPLE_COUNT,    self.train_samples)
+        gguf_writer.add_uint32(LLM_KV_TRAINING_TOKEN_COUNT,     self.train_tokens)
+        self.model.save_gguf(gguf_writer)
+        self.opt_ctx.save_gguf(gguf_writer)
+
+def handle_args():
+    parser = argparse.ArgumentParser(description = 'Convert finetune checkpoints to GGUF')
+    parser.add_argument('--input',  '-i', type = Path, help = 'Input finetune checkpoint filename', required=True)
+    parser.add_argument('--output', '-o', type = Path, help = 'Output GGUF filename', required=True)
+    parser.add_argument('--ff', type = int, help = "Feedforward size, if not provided compute from n_mult. Provide this if you get 'ValueError: Tensor.load: Expected number of elements does not match what is read from file'", required=False)
+    return parser.parse_args()
+
+def main():
+    cfg = handle_args()
+    print(cfg)
+    data = np.memmap(cfg.input, mode = 'r')
+    chk = LoraCheckpoint(n_ff = cfg.ff)
+    offset = 0
+    offset = chk.load(data, offset)
+    # we should have read all available data
+    assert(offset == len(data))
+
+    gguf_writer = gguf.GGUFWriter(cfg.output, gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA], use_temp_file = False)
+    chk.save_gguf(gguf_writer)
+    print("    gguf: write header")
+    gguf_writer.write_header_to_file()
+    print("    gguf: write metadata")
+    gguf_writer.write_kv_data_to_file()
+    print("    gguf: write tensors")
+    gguf_writer.write_tensors_to_file()
+    gguf_writer.close()
+
+if __name__ == '__main__':
+    main()

+ 1935 - 0
examples/finetune/finetune.cpp

@@ -0,0 +1,1935 @@
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "llama.h"
+#include "common.h"
+#include "train.h"
+#include <unordered_map>
+#include <vector>
+#include <cassert>
+#include <climits>
+#include <cstring>
+#include <cstdarg>
+#include <ctime>
+#include <random>
+#include <stdexcept>
+#include <algorithm>
+#include <string>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+static const size_t tensor_alignment = 32;
+
+struct my_llama_hparams {
+    uint32_t n_vocab    = 32000;
+    uint32_t n_ctx      = 512;
+    uint32_t n_embd     = 4096;
+    uint32_t n_ff       = 11008;
+    uint32_t n_head     = 32;
+    uint32_t n_head_kv  = 32;
+    uint32_t n_layer    = 32;
+
+    // float f_norm_eps     = 1e-5f; // falcon
+    float f_norm_rms_eps = 1e-5f; // llama
+
+    float rope_freq_base  = 10000.0f;
+    float rope_freq_scale = 1.0f;
+
+    uint32_t n_gqa() const {
+        return n_head/n_head_kv;
+    }
+
+    uint32_t n_embd_head() const {
+        return n_embd/n_head;
+    }
+
+    uint32_t n_embd_gqa() const {
+        return n_embd/n_gqa();
+    }
+
+    bool operator!=(const my_llama_hparams& other) const {
+        return memcmp(this, &other, sizeof(other));
+    }
+};
+
+struct my_llama_layer {
+    // normalization
+    struct ggml_tensor * attention_norm;
+
+    // attention
+    struct ggml_tensor * wq;
+    struct ggml_tensor * wk;
+    struct ggml_tensor * wv;
+    struct ggml_tensor * wo;
+
+    // normalization
+    struct ggml_tensor * ffn_norm;
+
+    // ff
+    struct ggml_tensor * w1;
+    struct ggml_tensor * w2;
+    struct ggml_tensor * w3;
+};
+
+struct my_llama_model {
+    struct my_llama_hparams hparams;
+
+    struct ggml_tensor * tok_embeddings;
+
+    struct ggml_tensor * norm;
+    struct ggml_tensor * output;
+
+    std::vector<my_llama_layer> layers;
+};
+
+struct my_llama_lora_hparams {
+    uint32_t lora_r = 1;
+    uint32_t lora_alpha = 1;
+    uint32_t n_rank_attention_norm = 1;
+    uint32_t n_rank_wq = 4;
+    uint32_t n_rank_wk = 4;
+    uint32_t n_rank_wv = 4;
+    uint32_t n_rank_wo = 4;
+    uint32_t n_rank_ffn_norm = 1;
+    uint32_t n_rank_w1 = 4;
+    uint32_t n_rank_w2 = 4;
+    uint32_t n_rank_w3 = 4;
+    uint32_t n_rank_tok_embeddings = 4;
+    uint32_t n_rank_norm = 1;
+    uint32_t n_rank_output = 4;
+
+    bool operator!=(const my_llama_lora_hparams& other) const {
+        return memcmp(this, &other, sizeof(other));
+    }
+};
+
+struct my_llama_lora_layer {
+    // normalization
+    struct ggml_tensor * attention_norm_a;
+    struct ggml_tensor * attention_norm_b;
+
+    // attention
+    struct ggml_tensor * wq_a;
+    struct ggml_tensor * wq_b;
+    struct ggml_tensor * wk_a;
+    struct ggml_tensor * wk_b;
+    struct ggml_tensor * wv_a;
+    struct ggml_tensor * wv_b;
+    struct ggml_tensor * wo_a;
+    struct ggml_tensor * wo_b;
+
+    // normalization
+    struct ggml_tensor * ffn_norm_a;
+    struct ggml_tensor * ffn_norm_b;
+
+    // ff
+    struct ggml_tensor * w1_a;
+    struct ggml_tensor * w1_b;
+    struct ggml_tensor * w2_a;
+    struct ggml_tensor * w2_b;
+    struct ggml_tensor * w3_a;
+    struct ggml_tensor * w3_b;
+};
+
+struct my_llama_lora {
+    struct ggml_context * ctx = NULL;
+    std::vector<uint8_t> data;
+
+    my_llama_lora_hparams hparams;
+
+    struct ggml_tensor * tok_embeddings_a;
+    struct ggml_tensor * tok_embeddings_b;
+
+    struct ggml_tensor * norm_a;
+    struct ggml_tensor * norm_b;
+    struct ggml_tensor * output_a;
+    struct ggml_tensor * output_b;
+
+    std::vector<my_llama_lora_layer> layers;
+};
+
+// gguf constants
+static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA   = "finetune_lora";
+static const char * LLM_KV_TRAINING_TYPE                 = "training.type";
+
+static const char * LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD  = "training.lora.rank.token_embd";
+static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm";
+static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT      = "training.lora.rank.output";
+static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_NORM   = "training.lora.rank.attn_norm";
+static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_Q      = "training.lora.rank.attn_q";
+static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_K      = "training.lora.rank.attn_k";
+static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_V      = "training.lora.rank.attn_v";
+static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_OUT    = "training.lora.rank.attn_output";
+static const char * LLM_KV_TRAINING_LORA_RANK_FFN_NORM    = "training.lora.rank.ffn_norm";
+static const char * LLM_KV_TRAINING_LORA_RANK_FFN_GATE    = "training.lora.rank.ffn_gate";
+static const char * LLM_KV_TRAINING_LORA_RANK_FFN_DOWN    = "training.lora.rank.ffn_down";
+static const char * LLM_KV_TRAINING_LORA_RANK_FFN_UP      = "training.lora.rank.ffn_up";
+
+// gguf constants (sync with gguf.py)
+
+static const char * LLM_KV_GENERAL_ARCHITECTURE        = "general.architecture";
+static const char * LLM_KV_GENERAL_FILE_TYPE           = "general.file_type";
+
+static const char * LLM_KV_CONTEXT_LENGTH              = "%s.context_length";
+static const char * LLM_KV_EMBEDDING_LENGTH            = "%s.embedding_length";
+static const char * LLM_KV_BLOCK_COUNT                 = "%s.block_count";
+static const char * LLM_KV_FEED_FORWARD_LENGTH         = "%s.feed_forward_length";
+static const char * LLM_KV_ATTENTION_HEAD_COUNT        = "%s.attention.head_count";
+static const char * LLM_KV_ATTENTION_HEAD_COUNT_KV     = "%s.attention.head_count_kv";
+static const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
+static const char * LLM_KV_ROPE_DIMENSION_COUNT        = "%s.rope.dimension_count";
+static const char * LLM_KV_ROPE_FREQ_BASE              = "%s.rope.freq_base"; // TODO load in llama.cpp
+static const char * LLM_KV_ROPE_SCALE_LINEAR           = "%s.rope.scale_linear";
+
+static const char * LLM_TENSOR_TOKEN_EMBD    = "token_embd";
+static const char * LLM_TENSOR_OUTPUT_NORM   = "output_norm";
+static const char * LLM_TENSOR_OUTPUT        = "output";
+static const char * LLM_TENSOR_ATTN_NORM     = "blk.%d.attn_norm";
+static const char * LLM_TENSOR_ATTN_Q        = "blk.%d.attn_q";
+static const char * LLM_TENSOR_ATTN_K        = "blk.%d.attn_k";
+static const char * LLM_TENSOR_ATTN_V        = "blk.%d.attn_v";
+static const char * LLM_TENSOR_ATTN_OUT      = "blk.%d.attn_output";
+static const char * LLM_TENSOR_FFN_NORM      = "blk.%d.ffn_norm";
+static const char * LLM_TENSOR_FFN_GATE      = "blk.%d.ffn_gate";
+static const char * LLM_TENSOR_FFN_DOWN      = "blk.%d.ffn_down";
+static const char * LLM_TENSOR_FFN_UP        = "blk.%d.ffn_up";
+
+static void print_params(struct my_llama_hparams * params) {
+    printf("%s: n_vocab:   %u\n", __func__, params->n_vocab);
+    printf("%s: n_ctx:     %u\n", __func__, params->n_ctx);
+    printf("%s: n_embd:    %u\n", __func__, params->n_embd);
+    printf("%s: n_ff:      %u\n", __func__, params->n_ff);
+    printf("%s: n_head:    %u\n", __func__, params->n_head);
+    printf("%s: n_head_kv: %u\n", __func__, params->n_head_kv);
+    printf("%s: n_layer:   %u\n", __func__, params->n_layer);
+    printf("%s: norm_rms_eps          : %f\n", __func__, params->f_norm_rms_eps);
+    printf("%s: rope_freq_base        : %f\n", __func__, params->rope_freq_base);
+    printf("%s: rope_freq_scale       : %f\n", __func__, params->rope_freq_scale);
+}
+
+static void print_lora_params(struct my_llama_lora_hparams * params) {
+    printf("%s: n_rank_attention_norm : %u\n", __func__, params->n_rank_attention_norm);
+    printf("%s: n_rank_wq             : %u\n", __func__, params->n_rank_wq);
+    printf("%s: n_rank_wk             : %u\n", __func__, params->n_rank_wk);
+    printf("%s: n_rank_wv             : %u\n", __func__, params->n_rank_wv);
+    printf("%s: n_rank_wo             : %u\n", __func__, params->n_rank_wo);
+    printf("%s: n_rank_ffn_norm       : %u\n", __func__, params->n_rank_ffn_norm);
+    printf("%s: n_rank_w1             : %u\n", __func__, params->n_rank_w1);
+    printf("%s: n_rank_w2             : %u\n", __func__, params->n_rank_w2);
+    printf("%s: n_rank_w3             : %u\n", __func__, params->n_rank_w3);
+    printf("%s: n_rank_tok_embeddings : %u\n", __func__, params->n_rank_tok_embeddings);
+    printf("%s: n_rank_norm           : %u\n", __func__, params->n_rank_norm);
+    printf("%s: n_rank_output         : %u\n", __func__, params->n_rank_output);
+}
+
+#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
+{ \
+    const std::string skey(key); \
+    const int kid = gguf_find_key(ctx, skey.c_str()); \
+    if (kid >= 0) { \
+        enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
+        if (ktype != (type)) { \
+            die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
+        } \
+        (dst) = func(ctx, kid); \
+    } else if (req) { \
+        die_fmt("key not found in model: %s", skey.c_str()); \
+    } \
+}
+
+static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_hparams * hparams, const char * expected_arch) {
+    std::string arch;
+
+    GGUF_GET_KEY(ctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE);
+    if (expected_arch != NULL) {
+        if (arch != expected_arch) {
+            printf("%s: arch=%s expected_arch=%s\n", __func__, arch.c_str(), expected_arch);
+        }
+        GGML_ASSERT(arch == expected_arch);
+    }
+
+    std::vector<char> keybuf;
+    keybuf.resize(512);
+    auto kv = [&arch, &keybuf](const char * key) -> const char * {
+        snprintf(keybuf.data(), keybuf.size(), key, arch.c_str());
+        return keybuf.data();
+    };
+
+    GGUF_GET_KEY(ctx, hparams->n_embd,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_EMBEDDING_LENGTH));
+    GGUF_GET_KEY(ctx, hparams->n_ctx,          gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
+    GGUF_GET_KEY(ctx, hparams->n_ff,           gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_FEED_FORWARD_LENGTH));
+    GGUF_GET_KEY(ctx, hparams->n_head,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
+    GGUF_GET_KEY(ctx, hparams->n_layer,        gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_BLOCK_COUNT));
+
+    // n_head_kv is optional, default to n_head
+    hparams->n_head_kv = hparams->n_head;
+    GGUF_GET_KEY(ctx, hparams->n_head_kv,      gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
+
+    float rope_freq_scale = 1.0f;
+    GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+    GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
+    GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    if (rope_freq_scale != 1.0f) {
+        hparams->rope_freq_scale = 1.0f / rope_freq_scale;
+    }
+}
+
+static void init_model(struct llama_model * input, struct my_llama_model * model, const char * fn_model, uint32_t n_ctx) {
+    auto & hparams = model->hparams;
+
+    std::vector<char> tn_buf;
+    tn_buf.resize(GGML_MAX_NAME);
+    auto tn = [&tn_buf](const char * key) -> const char * {
+        snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
+        return tn_buf.data();
+    };
+    auto tni = [&tn_buf](const char * key, int bid) -> const char * {
+        snprintf(tn_buf.data(), tn_buf.size(), key, bid);
+        std::string s = tn_buf.data();
+        snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
+        return tn_buf.data();
+    };
+
+
+    // get parameters directly from gguf file
+    {
+        struct gguf_init_params params = {
+            /*.no_alloc = */ false,
+            /*.ctx      = */ NULL,
+        };
+        struct gguf_context * mctx = gguf_init_from_file(fn_model, params);
+
+        load_model_hparams_gguf(mctx, &hparams, "llama");
+
+        gguf_free(mctx);
+    }
+    hparams.n_vocab = llama_model_n_vocab(input);
+    hparams.n_ctx = n_ctx;
+
+    // get tensors from llama_model (possibly mmapped)
+    model->tok_embeddings = llama_get_model_tensor(input, tn(LLM_TENSOR_TOKEN_EMBD));
+    model->norm           = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT_NORM));
+    model->output         = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT));
+
+    assert_shape_2d(model->tok_embeddings, hparams.n_embd, hparams.n_vocab);
+    assert_shape_1d(model->norm,           hparams.n_embd);
+    assert_shape_2d(model->output,         hparams.n_embd, hparams.n_vocab);
+
+    model->layers.resize(hparams.n_layer);
+    for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+        auto & layer = model->layers[i];
+
+        layer.attention_norm = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_NORM, i));
+        layer.wq             = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_Q, i));
+        layer.wk             = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_K, i));
+        layer.wv             = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_V, i));
+        layer.wo             = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_OUT, i));
+        layer.ffn_norm       = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_NORM, i));
+        layer.w1             = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_GATE, i));
+        layer.w2             = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_DOWN, i));
+        layer.w3             = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_UP, i));
+
+        assert_shape_1d(layer.attention_norm, hparams.n_embd);
+        assert_shape_2d(layer.wq,             hparams.n_embd, hparams.n_embd);
+        assert_shape_2d(layer.wk,             hparams.n_embd, hparams.n_embd);
+        assert_shape_2d(layer.wv,             hparams.n_embd, hparams.n_embd);
+        assert_shape_2d(layer.wo,             hparams.n_embd, hparams.n_embd);
+        assert_shape_1d(layer.ffn_norm,       hparams.n_embd);
+        assert_shape_2d(layer.w1,             hparams.n_embd, hparams.n_ff);
+        assert_shape_2d(layer.w2,             hparams.n_ff,   hparams.n_embd);
+        assert_shape_2d(layer.w3,             hparams.n_embd, hparams.n_ff);
+    }
+}
+
+static void set_param_lora(struct my_llama_lora * lora) {
+    const uint32_t n_layer = lora->layers.size();
+
+    struct ggml_context* ctx = lora->ctx;
+
+    ggml_set_param(ctx, lora->tok_embeddings_a);
+    ggml_set_param(ctx, lora->tok_embeddings_b);
+    ggml_set_param(ctx, lora->norm_a);
+    ggml_set_param(ctx, lora->norm_b);
+    ggml_set_param(ctx, lora->output_a);
+    ggml_set_param(ctx, lora->output_b);
+
+    for (uint32_t i = 0; i < n_layer; ++i) {
+        auto & layer = lora->layers[i];
+
+        ggml_set_param(ctx, layer.attention_norm_a);
+        ggml_set_param(ctx, layer.attention_norm_b);
+        ggml_set_param(ctx, layer.wq_a);
+        ggml_set_param(ctx, layer.wq_b);
+        ggml_set_param(ctx, layer.wk_a);
+        ggml_set_param(ctx, layer.wk_b);
+        ggml_set_param(ctx, layer.wv_a);
+        ggml_set_param(ctx, layer.wv_b);
+        ggml_set_param(ctx, layer.wo_a);
+        ggml_set_param(ctx, layer.wo_b);
+        ggml_set_param(ctx, layer.ffn_norm_a);
+        ggml_set_param(ctx, layer.ffn_norm_b);
+        ggml_set_param(ctx, layer.w1_a);
+        ggml_set_param(ctx, layer.w1_b);
+        ggml_set_param(ctx, layer.w2_a);
+        ggml_set_param(ctx, layer.w2_b);
+        ggml_set_param(ctx, layer.w3_a);
+        ggml_set_param(ctx, layer.w3_b);
+    }
+}
+
+static void alloc_lora(struct ggml_allocr * alloc, struct my_llama_lora * lora) {
+    ggml_allocr_alloc(alloc, lora->tok_embeddings_a);
+    ggml_allocr_alloc(alloc, lora->tok_embeddings_b);
+    ggml_allocr_alloc(alloc, lora->norm_a);
+    ggml_allocr_alloc(alloc, lora->norm_b);
+    ggml_allocr_alloc(alloc, lora->output_a);
+    ggml_allocr_alloc(alloc, lora->output_b);
+    for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+        auto & layer = lora->layers[i];
+        ggml_allocr_alloc(alloc, layer.attention_norm_a);
+        ggml_allocr_alloc(alloc, layer.attention_norm_b);
+        ggml_allocr_alloc(alloc, layer.wq_a);
+        ggml_allocr_alloc(alloc, layer.wq_b);
+        ggml_allocr_alloc(alloc, layer.wk_a);
+        ggml_allocr_alloc(alloc, layer.wk_b);
+        ggml_allocr_alloc(alloc, layer.wv_a);
+        ggml_allocr_alloc(alloc, layer.wv_b);
+        ggml_allocr_alloc(alloc, layer.wo_a);
+        ggml_allocr_alloc(alloc, layer.wo_b);
+        ggml_allocr_alloc(alloc, layer.ffn_norm_a);
+        ggml_allocr_alloc(alloc, layer.ffn_norm_b);
+        ggml_allocr_alloc(alloc, layer.w1_a);
+        ggml_allocr_alloc(alloc, layer.w1_b);
+        ggml_allocr_alloc(alloc, layer.w2_a);
+        ggml_allocr_alloc(alloc, layer.w2_b);
+        ggml_allocr_alloc(alloc, layer.w3_a);
+        ggml_allocr_alloc(alloc, layer.w3_b);
+    }
+    ggml_allocr_alloc(alloc, lora->tok_embeddings_a->grad);
+    ggml_allocr_alloc(alloc, lora->tok_embeddings_b->grad);
+    ggml_allocr_alloc(alloc, lora->norm_a->grad);
+    ggml_allocr_alloc(alloc, lora->norm_b->grad);
+    ggml_allocr_alloc(alloc, lora->output_a->grad);
+    ggml_allocr_alloc(alloc, lora->output_b->grad);
+    for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+        auto & layer = lora->layers[i];
+        ggml_allocr_alloc(alloc, layer.attention_norm_a->grad);
+        ggml_allocr_alloc(alloc, layer.attention_norm_b->grad);
+        ggml_allocr_alloc(alloc, layer.wq_a->grad);
+        ggml_allocr_alloc(alloc, layer.wq_b->grad);
+        ggml_allocr_alloc(alloc, layer.wk_a->grad);
+        ggml_allocr_alloc(alloc, layer.wk_b->grad);
+        ggml_allocr_alloc(alloc, layer.wv_a->grad);
+        ggml_allocr_alloc(alloc, layer.wv_b->grad);
+        ggml_allocr_alloc(alloc, layer.wo_a->grad);
+        ggml_allocr_alloc(alloc, layer.wo_b->grad);
+        ggml_allocr_alloc(alloc, layer.ffn_norm_a->grad);
+        ggml_allocr_alloc(alloc, layer.ffn_norm_b->grad);
+        ggml_allocr_alloc(alloc, layer.w1_a->grad);
+        ggml_allocr_alloc(alloc, layer.w1_b->grad);
+        ggml_allocr_alloc(alloc, layer.w2_a->grad);
+        ggml_allocr_alloc(alloc, layer.w2_b->grad);
+        ggml_allocr_alloc(alloc, layer.w3_a->grad);
+        ggml_allocr_alloc(alloc, layer.w3_b->grad);
+    }
+}
+
+static void init_lora(const struct my_llama_model * model, struct my_llama_lora * lora) {
+    const auto & lparams = lora->hparams;
+
+    const uint32_t n_embd     = model->hparams.n_embd;
+    const uint32_t n_embd_gqa = model->hparams.n_embd_gqa();
+    const uint32_t n_layer    = model->hparams.n_layer;
+    const uint32_t n_vocab    = model->hparams.n_vocab;
+    const uint32_t n_ff       = model->hparams.n_ff;
+
+    std::vector<char> tn_buf;
+    tn_buf.resize(GGML_MAX_NAME);
+    auto tn = [&tn_buf](const char * key, const char * suffix) -> const char * {
+        snprintf(tn_buf.data(), tn_buf.size(), "%s%s", key, suffix);
+        return tn_buf.data();
+    };
+    auto tni = [&tn_buf](const char * key, const char * suffix, int bid) -> const char * {
+        snprintf(tn_buf.data(), tn_buf.size(), key, bid);
+        std::string s = tn_buf.data();
+        snprintf(tn_buf.data(), tn_buf.size(), "%s%s", s.c_str(), suffix);
+        return tn_buf.data();
+    };
+
+    // context for lora tensors without their data
+    struct ggml_init_params ctx_lora_params;
+    ctx_lora_params.mem_size   = ggml_tensor_overhead()*2*(6 + n_layer*18);
+    ctx_lora_params.mem_buffer = NULL;
+    ctx_lora_params.no_alloc   = true;
+
+    struct ggml_context * ctx = ggml_init(ctx_lora_params);
+    lora->ctx = ctx;
+
+    lora->tok_embeddings_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_tok_embeddings, n_embd);
+    lora->tok_embeddings_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_tok_embeddings, n_vocab);
+    lora->norm_a           = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_norm, n_embd);
+    lora->norm_b           = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_norm, 1);
+    lora->output_a         = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_output, n_embd);
+    lora->output_b         = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_output, n_vocab);
+
+    ggml_set_name(lora->tok_embeddings_a, tn(LLM_TENSOR_TOKEN_EMBD,  ".weight.lora_a"));
+    ggml_set_name(lora->tok_embeddings_b, tn(LLM_TENSOR_TOKEN_EMBD,  ".weight.lora_b"));
+    ggml_set_name(lora->norm_a,           tn(LLM_TENSOR_OUTPUT_NORM, ".weight.lora_a"));
+    ggml_set_name(lora->norm_b,           tn(LLM_TENSOR_OUTPUT_NORM, ".weight.lora_b"));
+    ggml_set_name(lora->output_a,         tn(LLM_TENSOR_OUTPUT,      ".weight.lora_a"));
+    ggml_set_name(lora->output_b,         tn(LLM_TENSOR_OUTPUT,      ".weight.lora_b"));
+
+    lora->layers.resize(n_layer);
+    for (uint32_t i = 0; i < n_layer; ++i) {
+        auto & layer = lora->layers[i];
+
+        layer.attention_norm_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_attention_norm, n_embd);
+        layer.attention_norm_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_attention_norm, 1);
+
+        layer.wq_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wq, n_embd);
+        layer.wq_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wq, n_embd);
+        layer.wk_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wk, n_embd);
+        layer.wk_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wk, n_embd_gqa);
+        layer.wv_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wv, n_embd);
+        layer.wv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wv, n_embd_gqa);
+        layer.wo_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wo, n_embd);
+        layer.wo_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wo, n_embd);
+
+        layer.ffn_norm_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_norm, n_embd);
+        layer.ffn_norm_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_norm, 1);
+
+        layer.w1_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w1, n_embd);
+        layer.w1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w1, n_ff);
+        layer.w2_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w2, n_ff);
+        layer.w2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w2, n_embd);
+        layer.w3_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w3, n_embd);
+        layer.w3_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w3, n_ff);
+
+        ggml_set_name(layer.attention_norm_a, tni(LLM_TENSOR_ATTN_NORM, ".weight.lora_a", i));
+        ggml_set_name(layer.attention_norm_b, tni(LLM_TENSOR_ATTN_NORM, ".weight.lora_b", i));
+        ggml_set_name(layer.wq_a,             tni(LLM_TENSOR_ATTN_Q,    ".weight.lora_a", i));
+        ggml_set_name(layer.wq_b,             tni(LLM_TENSOR_ATTN_Q,    ".weight.lora_b", i));
+        ggml_set_name(layer.wk_a,             tni(LLM_TENSOR_ATTN_K,    ".weight.lora_a", i));
+        ggml_set_name(layer.wk_b,             tni(LLM_TENSOR_ATTN_K,    ".weight.lora_b", i));
+        ggml_set_name(layer.wv_a,             tni(LLM_TENSOR_ATTN_V,    ".weight.lora_a", i));
+        ggml_set_name(layer.wv_b,             tni(LLM_TENSOR_ATTN_V,    ".weight.lora_b", i));
+        ggml_set_name(layer.wo_a,             tni(LLM_TENSOR_ATTN_OUT,  ".weight.lora_a", i));
+        ggml_set_name(layer.wo_b,             tni(LLM_TENSOR_ATTN_OUT,  ".weight.lora_b", i));
+        ggml_set_name(layer.ffn_norm_a,       tni(LLM_TENSOR_FFN_NORM,  ".weight.lora_a", i));
+        ggml_set_name(layer.ffn_norm_b,       tni(LLM_TENSOR_FFN_NORM,  ".weight.lora_b", i));
+        ggml_set_name(layer.w1_a,             tni(LLM_TENSOR_FFN_GATE,  ".weight.lora_a", i));
+        ggml_set_name(layer.w1_b,             tni(LLM_TENSOR_FFN_GATE,  ".weight.lora_b", i));
+        ggml_set_name(layer.w2_a,             tni(LLM_TENSOR_FFN_DOWN,  ".weight.lora_a", i));
+        ggml_set_name(layer.w2_b,             tni(LLM_TENSOR_FFN_DOWN,  ".weight.lora_b", i));
+        ggml_set_name(layer.w3_a,             tni(LLM_TENSOR_FFN_UP,    ".weight.lora_a", i));
+        ggml_set_name(layer.w3_b,             tni(LLM_TENSOR_FFN_UP,    ".weight.lora_b", i));
+    }
+
+    set_param_lora(lora);
+
+    // measure data size
+    struct ggml_allocr * alloc = NULL;
+    alloc = ggml_allocr_new_measure(tensor_alignment);
+    alloc_lora(alloc, lora);
+
+    // allocate data
+    lora->data.resize(ggml_allocr_max_size(alloc) + tensor_alignment);
+    ggml_allocr_free(alloc);
+    alloc = ggml_allocr_new(lora->data.data(), lora->data.size(), tensor_alignment);
+    alloc_lora(alloc, lora);
+    ggml_allocr_free(alloc);
+}
+
+static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, float std, float min, float max) {
+    const uint32_t n_layer = lora->layers.size();
+
+    struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
+
+    randomize_tensor_normal(lora->tok_embeddings_a, rnd);
+    randomize_tensor_normal(lora->tok_embeddings_b, rnd);
+    randomize_tensor_normal(lora->norm_a,           rnd);
+    randomize_tensor_normal(lora->norm_b,           rnd);
+    randomize_tensor_normal(lora->output_a,         rnd);
+    randomize_tensor_normal(lora->output_b,         rnd);
+
+    for (uint32_t i = 0; i < n_layer; ++i) {
+        auto & layer = lora->layers[i];
+        randomize_tensor_normal(layer.attention_norm_a, rnd);
+        randomize_tensor_normal(layer.attention_norm_b, rnd);
+
+        randomize_tensor_normal(layer.wq_a, rnd);
+        randomize_tensor_normal(layer.wq_b, rnd);
+        randomize_tensor_normal(layer.wk_a, rnd);
+        randomize_tensor_normal(layer.wk_b, rnd);
+        randomize_tensor_normal(layer.wv_a, rnd);
+        randomize_tensor_normal(layer.wv_b, rnd);
+        randomize_tensor_normal(layer.wo_a, rnd);
+        randomize_tensor_normal(layer.wo_b, rnd);
+
+        randomize_tensor_normal(layer.ffn_norm_a, rnd);
+        randomize_tensor_normal(layer.ffn_norm_b, rnd);
+
+        randomize_tensor_normal(layer.w1_a, rnd);
+        randomize_tensor_normal(layer.w1_b, rnd);
+        randomize_tensor_normal(layer.w2_a, rnd);
+        randomize_tensor_normal(layer.w2_b, rnd);
+        randomize_tensor_normal(layer.w3_a, rnd);
+        randomize_tensor_normal(layer.w3_b, rnd);
+    }
+
+    free_random_normal_distribution(rnd);
+}
+
+static struct ggml_tensor * llama_build_lora_finetune_graphs(
+        struct my_llama_model * model,
+        struct my_llama_lora  * lora,
+        struct ggml_allocr    * alloc,
+        struct ggml_context   * ctx,
+        struct ggml_cgraph    * gf,
+        struct ggml_cgraph    * gb,
+        struct ggml_cgraph    * gb_tmp,
+        struct ggml_tensor  * * logits,
+        struct ggml_tensor    * tokens_input,
+        struct ggml_tensor    * targets,
+        const  int              n_tokens,
+        const  int              n_batch,
+        const  bool             enable_flash_attn,
+        const  bool             enable_checkpointing) {
+
+    ggml_set_scratch(ctx, { 0, 0, nullptr, });
+    const int n_past = 0;
+    const int N = n_tokens;
+    const auto & hparams  = model->hparams;
+    const int n_ctx       = hparams.n_ctx;
+    const int n_vocab     = hparams.n_vocab;
+    const int n_embd      = hparams.n_embd;
+    const int n_layer     = hparams.n_layer;
+    const int n_head      = hparams.n_head;
+    const int n_head_kv   = hparams.n_head_kv;
+    const int n_ff        = hparams.n_ff;
+    const int n_rot       = hparams.n_embd_head();
+    const int n_embd_head = hparams.n_embd_head();
+    const int n_embd_gqa  = hparams.n_embd_gqa();
+    const float rms_norm_eps    = hparams.f_norm_rms_eps;
+    const float rope_freq_base  = hparams.rope_freq_base;
+    const float rope_freq_scale = hparams.rope_freq_scale;
+
+    GGML_ASSERT((size_t) n_layer == lora->layers.size());
+
+    auto set_name = [](struct ggml_tensor * t, const char * n) {
+        ggml_set_name(t, n);
+        if (t->grad) {
+            ggml_format_name(t->grad, "%s->grad", n);
+        }
+    };
+
+    // KQ_pos - contains the positions
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
+    {
+        int * data = (int *) KQ_pos->data;
+        for (int i = 0; i < N; ++i) {
+            data[i] = n_past + i;
+        }
+    }
+
+    // rope has so much parameters that we make a custom function for it
+    auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
+                (struct ggml_tensor * t) -> struct ggml_tensor * {
+        // not capturing these, to silcence warnings
+        const int rope_mode = 0;
+
+        return ggml_rope_custom(ctx,
+            t, KQ_pos, n_rot, rope_mode, n_ctx,
+            rope_freq_base, rope_freq_scale);
+    };
+
+    set_name(tokens_input, "tokens_input");
+    set_name(targets,      "targets");
+
+    GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
+
+    auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
+        if (ggml_is_quantized(a->type)) {
+            return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
+        } else if (a->type == GGML_TYPE_F32) {
+            return ggml_add(ctx, a, b);
+        } else {
+            die_fmt("%s: Finetuning on tensors with type '%s' is not yet supported.\n",
+                __func__, ggml_type_name(a->type));
+        }
+    };
+
+    struct ggml_tensor * tok_embeddings = add_to_f32(ctx, model->tok_embeddings, ggml_mul_mat(ctx, lora->tok_embeddings_a, lora->tok_embeddings_b));
+    struct ggml_tensor * norm           = add_to_f32(ctx, model->norm, ggml_mul_mat(ctx, lora->norm_a, lora->norm_b));
+    struct ggml_tensor * output         = add_to_f32(ctx, model->output, ggml_mul_mat(ctx, lora->output_a, lora->output_b));
+
+    struct ggml_tensor * t00 = ggml_reshape_1d(ctx, tokens_input, N*n_batch);  set_name(t00, "t00"); assert_shape_1d(t00, N*n_batch);
+    struct ggml_tensor * t01 = ggml_get_rows(ctx, tok_embeddings, t00);        set_name(t01, "t01"); assert_shape_2d(t01, n_embd, N*n_batch);
+
+    struct ggml_tensor * cur = t01;
+
+    std::vector<struct ggml_tensor *> checkpoints;
+    if (enable_checkpointing) {
+        checkpoints.push_back(tokens_input);
+        checkpoints.push_back(targets);
+        checkpoints.push_back(t00);
+        checkpoints.push_back(t01);
+    }
+
+    struct ggml_tensor * kv_scale = NULL;
+    if (!enable_flash_attn) {
+        kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
+    }
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct my_llama_layer & layer = model->layers[il];
+        struct my_llama_lora_layer & llayer = lora->layers[il];
+
+        struct ggml_tensor * attention_norm = add_to_f32(ctx, layer.attention_norm, ggml_mul_mat(ctx, llayer.attention_norm_a, llayer.attention_norm_b));
+        struct ggml_tensor * ffn_norm = add_to_f32(ctx, layer.ffn_norm, ggml_mul_mat(ctx, llayer.ffn_norm_a, llayer.ffn_norm_b));
+        struct ggml_tensor * wq = add_to_f32(ctx, layer.wq, ggml_mul_mat(ctx, llayer.wq_a, llayer.wq_b));
+        struct ggml_tensor * wk = add_to_f32(ctx, layer.wk, ggml_mul_mat(ctx, llayer.wk_a, llayer.wk_b));
+        struct ggml_tensor * wv = add_to_f32(ctx, layer.wv, ggml_mul_mat(ctx, llayer.wv_a, llayer.wv_b));
+        struct ggml_tensor * wo = add_to_f32(ctx, layer.wo, ggml_mul_mat(ctx, llayer.wo_a, llayer.wo_b));
+        struct ggml_tensor * w1 = add_to_f32(ctx, layer.w1, ggml_mul_mat(ctx, llayer.w1_a, llayer.w1_b));
+        struct ggml_tensor * w2 = add_to_f32(ctx, layer.w2, ggml_mul_mat(ctx, llayer.w2_a, llayer.w2_b));
+        struct ggml_tensor * w3 = add_to_f32(ctx, layer.w3, ggml_mul_mat(ctx, llayer.w3_a, llayer.w3_b));
+
+        struct ggml_tensor * t02 = ggml_rms_norm     (ctx, cur, rms_norm_eps);                       set_name(t02, "t02");     assert_shape_2d(t02, n_embd, N*n_batch);
+        struct ggml_tensor * t03 = ggml_repeat       (ctx, attention_norm, t02);                     set_name(t03, "t03");     assert_shape_2d(t03, n_embd, N*n_batch);
+        struct ggml_tensor * t04 = ggml_mul          (ctx, t03, t02);                                set_name(t04, "t04");     assert_shape_2d(t04, n_embd, N*n_batch);
+        struct ggml_tensor * t05 = ggml_mul_mat      (ctx, wq, t04);                                 set_name(t05, "t05");     assert_shape_2d(t05, n_embd, N*n_batch);
+        struct ggml_tensor * t06 = ggml_reshape_4d   (ctx, t05, n_embd_head, n_head, N, n_batch);    set_name(t06, "t06");     assert_shape_4d(t06, n_embd_head, n_head, N, n_batch);
+        struct ggml_tensor * t07 = rope              (t06);                                          set_name(t07, "t07");     assert_shape_4d(t07, n_embd_head, n_head, N, n_batch);
+        struct ggml_tensor * t08 = ggml_mul_mat      (ctx, wk, t04);                                 set_name(t08, "t08");     assert_shape_2d(t08, n_embd_gqa, N*n_batch);
+        struct ggml_tensor * t09 = ggml_reshape_4d   (ctx, t08, n_embd_head, n_head_kv, N, n_batch); set_name(t09, "t09");     assert_shape_4d(t09, n_embd_head, n_head_kv, N, n_batch);
+        struct ggml_tensor * t10 = rope              (t09);                                          set_name(t10, "t10");     assert_shape_4d(t10, n_embd_head, n_head_kv, N, n_batch);
+
+        struct ggml_tensor * t11;
+        if (ggml_is_quantized(wv->type)) {
+            struct ggml_tensor * t11_1 = ggml_mul_mat  (ctx, wv, t04);                               set_name(t11_1, "t11_1"); assert_shape_2d(t11_1, n_embd_gqa, N*n_batch);
+            struct ggml_tensor * t11_2 = ggml_transpose(ctx, t11_1);                                 set_name(t11_2, "t11_2"); assert_shape_2d(t11_2, N*n_batch, n_embd_gqa);
+                                 t11   = ggml_cont     (ctx, t11_2);                                 set_name(t11, "t11");     assert_shape_2d(t11, N*n_batch, n_embd_gqa);
+        } else {
+                                 t11   = ggml_mul_mat  (ctx, t04, wv);                               set_name(t11, "t11");     assert_shape_2d(t11, N*n_batch, n_embd_gqa);
+        }
+
+        struct ggml_tensor * t12 = ggml_reshape_4d   (ctx, t11, N, n_batch, n_embd_head, n_head_kv); set_name(t12, "t12");     assert_shape_4d(t12, N, n_batch, n_embd_head, n_head_kv);
+        struct ggml_tensor * t13 = ggml_permute      (ctx, t07, 0, 2, 1, 3);                         set_name(t13, "t13");     assert_shape_4d(t13, n_embd_head, N, n_head, n_batch);
+        struct ggml_tensor * t14 = ggml_permute      (ctx, t10, 0, 2, 1, 3);                         set_name(t14, "t14");     assert_shape_4d(t14, n_embd_head, N, n_head_kv, n_batch);
+        struct ggml_tensor * t15 = ggml_permute      (ctx, t12, 0, 3, 1, 2);                         set_name(t15, "t15");     assert_shape_4d(t15, N, n_embd_head, n_head_kv, n_batch);
+        struct ggml_tensor * t16;
+        if (enable_flash_attn) {
+            t16 = ggml_flash_attn(ctx, t13, t14, t15, true);                                         set_name(t16, "t16");     assert_shape_4d(t16, n_embd_head, N, n_head, n_batch);
+        } else {
+            struct ggml_tensor * t16_0 = ggml_mul_mat              (ctx, t14, t13);                  set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
+            struct ggml_tensor * t16_1 = ggml_scale_inplace        (ctx, t16_0, kv_scale);           set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
+            struct ggml_tensor * t16_2 = ggml_diag_mask_inf_inplace(ctx, t16_1, n_past);             set_name(t16_2, "t16_2"); assert_shape_4d(t16_2, N, N, n_head, n_batch);
+            struct ggml_tensor * t16_3 = ggml_soft_max_inplace     (ctx, t16_2);                     set_name(t16_3, "t16_3"); assert_shape_4d(t16_3, N, N, n_head, n_batch);
+            t16 = ggml_mul_mat(ctx, t15, t16_3);                                                     set_name(t16, "t16");     assert_shape_4d(t16, n_embd_head, N, n_head, n_batch);
+        }
+        struct ggml_tensor * t17 = ggml_permute      (ctx, t16, 0, 2, 1, 3);                         set_name(t17, "t17");     assert_shape_4d(t17, n_embd_head, n_head, N, n_batch);
+        struct ggml_tensor * t18 = ggml_cont         (ctx, t17);                                     set_name(t18, "t18");     assert_shape_4d(t18, n_embd_head, n_head, N, n_batch);
+        struct ggml_tensor * t19 = ggml_reshape_2d   (ctx, t18, n_embd, N*n_batch);                  set_name(t19, "t19");     assert_shape_2d(t19, n_embd, N*n_batch);
+        struct ggml_tensor * t20 = ggml_mul_mat      (ctx, wo, t19);                                 set_name(t20, "t20");     assert_shape_2d(t20, n_embd, N*n_batch);
+        struct ggml_tensor * t21 = ggml_add          (ctx, t20, cur);                                set_name(t21, "t21");     assert_shape_2d(t21, n_embd, N*n_batch);
+        struct ggml_tensor * t22 = ggml_rms_norm     (ctx, t21, rms_norm_eps);                       set_name(t22, "t22");     assert_shape_2d(t22, n_embd, N*n_batch);
+        struct ggml_tensor * t23 = ggml_repeat       (ctx, ffn_norm, t22);                           set_name(t23, "t23");     assert_shape_2d(t23, n_embd, N*n_batch);
+        struct ggml_tensor * t24 = ggml_mul          (ctx, t23, t22);                                set_name(t24, "t24");     assert_shape_2d(t24, n_embd, N*n_batch);
+        struct ggml_tensor * t25 = ggml_mul_mat      (ctx, w3, t24);                                 set_name(t25, "t25");     assert_shape_2d(t25, n_ff, N*n_batch);
+        struct ggml_tensor * t26 = ggml_mul_mat      (ctx, w1, t24);                                 set_name(t26, "t26");     assert_shape_2d(t26, n_ff, N*n_batch);
+        struct ggml_tensor * t27 = ggml_silu         (ctx, t26);                                     set_name(t27, "t27");     assert_shape_2d(t27, n_ff, N*n_batch);
+        struct ggml_tensor * t28 = ggml_mul          (ctx, t27, t25);                                set_name(t28, "t28");     assert_shape_2d(t28, n_ff, N*n_batch);
+        struct ggml_tensor * t29 = ggml_mul_mat      (ctx, w2, t28);                                 set_name(t29, "t29");     assert_shape_2d(t29, n_embd, N*n_batch);
+        struct ggml_tensor * t30 = ggml_add          (ctx, t29, t21);                                set_name(t30, "t30");     assert_shape_2d(t30, n_embd, N*n_batch);
+        cur = t30;
+        if (enable_checkpointing) {
+            checkpoints.push_back(cur);
+        }
+    }
+    struct ggml_tensor * t31   = ggml_rms_norm          (ctx, cur, rms_norm_eps);                    set_name(t31, "t31");     assert_shape_2d(t31, n_embd, N*n_batch);
+    struct ggml_tensor * t32   = ggml_repeat            (ctx, norm, t31);                            set_name(t32, "t32");     assert_shape_2d(t32, n_embd, N*n_batch);
+    struct ggml_tensor * t33   = ggml_mul               (ctx, t32, t31);                             set_name(t33, "t33");     assert_shape_2d(t33, n_embd, N*n_batch);
+    struct ggml_tensor * t34   = ggml_mul_mat           (ctx, output, t33);                          set_name(t34, "t34");     assert_shape_2d(t34, n_vocab, N*n_batch);
+    struct ggml_tensor * t35   = ggml_reshape_3d        (ctx, t34, n_vocab, N, n_batch);             set_name(t35, "t35");     assert_shape_3d(t35, n_vocab, N, n_batch);
+    struct ggml_tensor * t36   = ggml_cross_entropy_loss(ctx, t35, targets);                         set_name(t36, "t36");     assert_shape_1d(t36, 1);
+
+    if (enable_checkpointing) {
+        checkpoints.push_back(t31);
+        checkpoints.push_back(t32);
+        checkpoints.push_back(t33);
+        checkpoints.push_back(t34);
+        checkpoints.push_back(t35);
+        checkpoints.push_back(t36);
+    }
+
+    ggml_build_forward_expand(gf, t36);
+
+    if (enable_checkpointing) {
+        ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
+    } else {
+        *gb = *gf;
+        ggml_build_backward_expand(ctx, gf, gb, true);
+    }
+
+    GGML_ASSERT(alloc != NULL);
+
+    // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
+    int n_leafs_before = gb->n_leafs;
+    int n_nodes_before = gb->n_nodes;
+    struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
+    // output tensors
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
+    // input gradient
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
+    GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
+    ggml_allocr_alloc(alloc, t36->grad);
+
+    // make sure base model tensors data cannot be used in viewable operations
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one));
+    for (int il = 0; il < n_layer; ++il) {
+        struct my_llama_layer & layer = model->layers[il];
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
+    }
+
+    // allocating checkpoints in one block to reduce memory fragmentation
+    // note: they will be freed in reverse order
+    for (unsigned int i = 0; i < checkpoints.size(); ++i) {
+        if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
+            ggml_allocr_alloc(alloc, checkpoints[i]);
+        }
+    }
+
+    ggml_allocr_alloc_graph(alloc, gb);
+
+    // remove the additional nodes and leafs
+    for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
+        gb->leafs[i] = NULL;
+    }
+    for (int i = n_nodes_before; i < gb->n_nodes; ++i) {
+        gb->nodes[i] = NULL;
+    }
+    gb->n_leafs = n_leafs_before;
+    gb->n_nodes = n_nodes_before;
+
+    *logits = t35;
+    return t36;
+}
+
+static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora) {
+    // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
+
+    std::string arch;
+
+    std::vector<char> keybuf;
+    keybuf.resize(512);
+
+    GGUF_GET_KEY(fctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE);
+    GGML_ASSERT(arch == "llama");
+
+    uint32_t ftype_u;
+    GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE);
+    GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32);
+
+    struct my_llama_hparams hparams;
+    load_model_hparams_gguf(fctx, &hparams, arch.c_str());
+
+    // parameters that define tensor shapes must match
+    GGML_ASSERT(hparams.n_embd    == model->hparams.n_embd);
+    GGML_ASSERT(hparams.n_ff      == model->hparams.n_ff);
+    GGML_ASSERT(hparams.n_head    == model->hparams.n_head);
+    GGML_ASSERT(hparams.n_head_kv == model->hparams.n_head_kv);
+    GGML_ASSERT(hparams.n_layer   == model->hparams.n_layer);
+
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_tok_embeddings, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_norm,           gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_output,         gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_OUTPUT);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_attention_norm, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_NORM);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_wq,             gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_Q);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_wk,             gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_K);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_wv,             gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_V);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_wo,             gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_OUT);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_ffn_norm,       gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_NORM);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_w1,             gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_GATE);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_w2,             gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_DOWN);
+    GGUF_GET_KEY(fctx, lora->hparams.n_rank_w3,             gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_UP);
+
+    init_lora(model, lora);
+
+    copy_tensor_by_name(lora->tok_embeddings_a, f_ggml_ctx, ggml_get_name(lora->tok_embeddings_a));
+    copy_tensor_by_name(lora->tok_embeddings_b, f_ggml_ctx, ggml_get_name(lora->tok_embeddings_b));
+    copy_tensor_by_name(lora->norm_a,           f_ggml_ctx, ggml_get_name(lora->norm_a));
+    copy_tensor_by_name(lora->norm_b,           f_ggml_ctx, ggml_get_name(lora->norm_b));
+    copy_tensor_by_name(lora->output_a,         f_ggml_ctx, ggml_get_name(lora->output_a));
+    copy_tensor_by_name(lora->output_b,         f_ggml_ctx, ggml_get_name(lora->output_b));
+
+    for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+        auto & layer = lora->layers[i];
+        copy_tensor_by_name(layer.attention_norm_a, f_ggml_ctx, ggml_get_name(layer.attention_norm_a));
+        copy_tensor_by_name(layer.attention_norm_b, f_ggml_ctx, ggml_get_name(layer.attention_norm_b));
+        copy_tensor_by_name(layer.wq_a,             f_ggml_ctx, ggml_get_name(layer.wq_a));
+        copy_tensor_by_name(layer.wq_b,             f_ggml_ctx, ggml_get_name(layer.wq_b));
+        copy_tensor_by_name(layer.wk_a,             f_ggml_ctx, ggml_get_name(layer.wk_a));
+        copy_tensor_by_name(layer.wk_b,             f_ggml_ctx, ggml_get_name(layer.wk_b));
+        copy_tensor_by_name(layer.wv_a,             f_ggml_ctx, ggml_get_name(layer.wv_a));
+        copy_tensor_by_name(layer.wv_b,             f_ggml_ctx, ggml_get_name(layer.wv_b));
+        copy_tensor_by_name(layer.wo_a,             f_ggml_ctx, ggml_get_name(layer.wo_a));
+        copy_tensor_by_name(layer.wo_b,             f_ggml_ctx, ggml_get_name(layer.wo_b));
+        copy_tensor_by_name(layer.ffn_norm_a,       f_ggml_ctx, ggml_get_name(layer.ffn_norm_a));
+        copy_tensor_by_name(layer.ffn_norm_b,       f_ggml_ctx, ggml_get_name(layer.ffn_norm_b));
+        copy_tensor_by_name(layer.w1_a,             f_ggml_ctx, ggml_get_name(layer.w1_a));
+        copy_tensor_by_name(layer.w1_b,             f_ggml_ctx, ggml_get_name(layer.w1_b));
+        copy_tensor_by_name(layer.w2_a,             f_ggml_ctx, ggml_get_name(layer.w2_a));
+        copy_tensor_by_name(layer.w2_b,             f_ggml_ctx, ggml_get_name(layer.w2_b));
+        copy_tensor_by_name(layer.w3_a,             f_ggml_ctx, ggml_get_name(layer.w3_a));
+        copy_tensor_by_name(layer.w3_b,             f_ggml_ctx, ggml_get_name(layer.w3_b));
+    }
+}
+
+static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora) {
+    const char * arch = "llama";
+    enum llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
+
+    std::vector<char> keybuf;
+    keybuf.resize(512);
+    auto kv = [arch, &keybuf](const char * key) -> const char * {
+        snprintf(keybuf.data(), keybuf.size(), key, arch);
+        return keybuf.data();
+    };
+
+    gguf_set_val_str(fctx, LLM_KV_GENERAL_ARCHITECTURE, arch);
+    gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype);
+
+    gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH),              model->hparams.n_ctx);
+    gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH),            model->hparams.n_embd);
+    gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH),         model->hparams.n_ff);
+    gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT),        model->hparams.n_head);
+    gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV),     model->hparams.n_head_kv);
+    gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT),                 model->hparams.n_layer);
+    gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT),        model->hparams.n_embd_head());
+    gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), model->hparams.f_norm_rms_eps);
+    gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE),              model->hparams.rope_freq_base);
+    gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR),           model->hparams.rope_freq_scale);
+
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD,   lora->hparams.n_rank_tok_embeddings);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM,  lora->hparams.n_rank_norm);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_OUTPUT,       lora->hparams.n_rank_output);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_NORM,    lora->hparams.n_rank_attention_norm);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_Q,       lora->hparams.n_rank_wq);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_K,       lora->hparams.n_rank_wk);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_V,       lora->hparams.n_rank_wv);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_OUT,     lora->hparams.n_rank_wo);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_NORM,     lora->hparams.n_rank_ffn_norm);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_GATE,     lora->hparams.n_rank_w1);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_DOWN,     lora->hparams.n_rank_w2);
+    gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_UP,       lora->hparams.n_rank_w3);
+
+    gguf_add_tensor(fctx, lora->tok_embeddings_a);
+    gguf_add_tensor(fctx, lora->tok_embeddings_b);
+    gguf_add_tensor(fctx, lora->norm_a);
+    gguf_add_tensor(fctx, lora->norm_b);
+    gguf_add_tensor(fctx, lora->output_a);
+    gguf_add_tensor(fctx, lora->output_b);
+
+    for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+        auto & layer = lora->layers[i];
+
+        gguf_add_tensor(fctx, layer.attention_norm_a);
+        gguf_add_tensor(fctx, layer.attention_norm_b);
+        gguf_add_tensor(fctx, layer.wq_a);
+        gguf_add_tensor(fctx, layer.wq_b);
+        gguf_add_tensor(fctx, layer.wk_a);
+        gguf_add_tensor(fctx, layer.wk_b);
+        gguf_add_tensor(fctx, layer.wv_a);
+        gguf_add_tensor(fctx, layer.wv_b);
+        gguf_add_tensor(fctx, layer.wo_a);
+        gguf_add_tensor(fctx, layer.wo_b);
+        gguf_add_tensor(fctx, layer.ffn_norm_a);
+        gguf_add_tensor(fctx, layer.ffn_norm_b);
+        gguf_add_tensor(fctx, layer.w1_a);
+        gguf_add_tensor(fctx, layer.w1_b);
+        gguf_add_tensor(fctx, layer.w2_a);
+        gguf_add_tensor(fctx, layer.w2_b);
+        gguf_add_tensor(fctx, layer.w3_a);
+        gguf_add_tensor(fctx, layer.w3_b);
+    }
+}
+
+static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+    std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
+    GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
+    GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
+
+    load_train_state_gguf(fctx, f_ggml_ctx, train);
+    load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
+}
+
+static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+    gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
+    save_llama_lora_gguf(fctx, model, lora);
+    save_train_state_gguf(fctx, train);
+}
+
+static bool load_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+    struct ggml_context * f_ggml_ctx;
+    struct gguf_init_params params;
+    params.no_alloc = false;
+    params.ctx = &f_ggml_ctx;
+    struct gguf_context * fctx = gguf_init_from_file(filename, params);
+    if (fctx == NULL) {
+        return false;
+    }
+
+    load_checkpoint_lora_gguf(fctx, f_ggml_ctx, model, lora, train);
+
+    gguf_free(fctx);
+    return true;
+}
+
+static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+    printf("%s: saving to %s\n", __func__, filename);
+    struct gguf_context * fctx = gguf_init_empty();
+
+    save_checkpoint_lora_gguf(fctx, model, lora, train);
+
+    // write file
+    const bool only_meta = false;
+    gguf_write_to_file(fctx, filename, only_meta);
+    gguf_free(fctx);
+}
+
+struct llama_file {
+    // use FILE * so we don't have to re-open the file to mmap
+    FILE * fp;
+    size_t size;
+
+    llama_file(const char * fname, const char * mode) {
+        fp = std::fopen(fname, mode);
+        if (fp == NULL) {
+            size = 0;
+        } else {
+            seek(0, SEEK_END);
+            size = tell();
+            seek(0, SEEK_SET);
+        }
+    }
+
+    size_t tell() const {
+#ifdef _WIN32
+        __int64 ret = _ftelli64(fp);
+#else
+        long ret = std::ftell(fp);
+#endif
+        GGML_ASSERT(ret != -1); // this really shouldn't fail
+        return (size_t) ret;
+    }
+
+    void seek(size_t offset, int whence) {
+#ifdef _WIN32
+        int ret = _fseeki64(fp, (__int64) offset, whence);
+#else
+        int ret = std::fseek(fp, (long) offset, whence);
+#endif
+        GGML_ASSERT(ret == 0); // same
+    }
+
+    void read_raw(void * ptr, size_t size) {
+        if (size == 0) {
+            return;
+        }
+        errno = 0;
+        std::size_t ret = std::fread(ptr, size, 1, fp);
+        if (ferror(fp)) {
+            die_fmt("read error: %s", strerror(errno));
+        }
+        if (ret != 1) {
+            die("unexpectedly reached end of file");
+        }
+    }
+
+    std::uint32_t read_u32() {
+        std::uint32_t ret;
+        read_raw(&ret, sizeof(ret));
+        return ret;
+    }
+
+    std::string read_string(std::uint32_t len) {
+        std::vector<char> chars(len);
+        read_raw(chars.data(), len);
+        return std::string(chars.data(), len);
+    }
+
+    void write_raw(const void * ptr, size_t size) {
+        if (size == 0) {
+            return;
+        }
+        errno = 0;
+        size_t ret = std::fwrite(ptr, size, 1, fp);
+        if (ret != 1) {
+            die_fmt("write error: %s", strerror(errno));
+        }
+    }
+
+    void write_u32(std::uint32_t val) {
+        write_raw(&val, sizeof(val));
+    }
+
+    ~llama_file() {
+        if (fp) {
+            std::fclose(fp);
+        }
+    }
+};
+
+static void write_tensor(struct llama_file * file, struct ggml_tensor * tensor, const char * name) {
+    if (tensor == NULL) {
+        file->write_u32(0);
+        file->write_u32(0);
+        file->write_u32(GGML_TYPE_F32);
+        file->seek((0-file->tell()) & 31, SEEK_CUR);
+        return;
+    }
+    if (name == NULL) {
+        name = ggml_get_name(tensor);
+    }
+    uint32_t name_len = strlen(name);
+    uint32_t nd = tensor->n_dims;
+    uint32_t ne[4] = { (uint32_t)tensor->ne[0],
+                       (uint32_t)tensor->ne[1],
+                       (uint32_t)tensor->ne[2],
+                       (uint32_t)tensor->ne[3] };
+    file->write_u32(nd);
+    file->write_u32(name_len);
+    file->write_u32(tensor->type);
+    file->write_raw(ne, sizeof(ne[0]) * nd);
+    file->write_raw(name, name_len);
+    file->seek((0-file->tell()) & 31, SEEK_CUR);
+    file->write_raw(tensor->data, ggml_nbytes(tensor));
+}
+
+static void save_as_llama_lora(const char * filename, struct my_llama_lora * lora) {
+    printf("%s: saving to %s\n", __func__, filename);
+    struct llama_file file(filename, "wb");
+    if (file.fp == NULL) {
+        return;
+    }
+
+    std::vector<char> tn_buf;
+    tn_buf.resize(GGML_MAX_NAME);
+
+    auto tn = [&tn_buf](const char * key, const char * suffix) -> const char * {
+        snprintf(tn_buf.data(), tn_buf.size(), "%s%s", key, suffix);
+        return tn_buf.data();
+    };
+
+    auto tni = [&tn_buf](const char * key, int bid, const char * suffix) -> const char * {
+        snprintf(tn_buf.data(), tn_buf.size(), key, bid);
+        std::string s = tn_buf.data();
+        snprintf(tn_buf.data(), tn_buf.size(), "%s%s", s.c_str(), suffix);
+        return tn_buf.data();
+    };
+
+    uint32_t LLAMA_FILE_MAGIC_LORA = 0x67676C61; // 'ggla'
+    // write_magic
+    file.write_u32(LLAMA_FILE_MAGIC_LORA);   // magic
+    file.write_u32(1); // version
+    // write_hparams
+    file.write_u32(lora->hparams.lora_r);
+    file.write_u32(lora->hparams.lora_alpha);
+    // write tensors
+    write_tensor(&file, lora->tok_embeddings_a, tn(LLM_TENSOR_TOKEN_EMBD,  ".weight.loraA"));
+    write_tensor(&file, lora->tok_embeddings_b, tn(LLM_TENSOR_TOKEN_EMBD,  ".weight.loraB"));
+    write_tensor(&file, lora->norm_a,           tn(LLM_TENSOR_OUTPUT_NORM, ".weight.loraA"));
+    write_tensor(&file, lora->norm_b,           tn(LLM_TENSOR_OUTPUT_NORM, ".weight.loraB"));
+    write_tensor(&file, lora->output_a,         tn(LLM_TENSOR_OUTPUT,      ".weight.loraA"));
+    write_tensor(&file, lora->output_b,         tn(LLM_TENSOR_OUTPUT,      ".weight.loraB"));
+    for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+        auto & layer = lora->layers[i];
+        write_tensor(&file, layer.attention_norm_a, tni(LLM_TENSOR_ATTN_NORM, i, ".weight.loraA"));
+        write_tensor(&file, layer.attention_norm_b, tni(LLM_TENSOR_ATTN_NORM, i, ".weight.loraB"));
+        write_tensor(&file, layer.wq_a,             tni(LLM_TENSOR_ATTN_Q,    i, ".weight.loraA"));
+        write_tensor(&file, layer.wq_b,             tni(LLM_TENSOR_ATTN_Q,    i, ".weight.loraB"));
+        write_tensor(&file, layer.wk_a,             tni(LLM_TENSOR_ATTN_K,    i, ".weight.loraA"));
+        write_tensor(&file, layer.wk_b,             tni(LLM_TENSOR_ATTN_K,    i, ".weight.loraB"));
+        write_tensor(&file, layer.wv_a,             tni(LLM_TENSOR_ATTN_V,    i, ".weight.loraA"));
+        write_tensor(&file, layer.wv_b,             tni(LLM_TENSOR_ATTN_V,    i, ".weight.loraB"));
+        write_tensor(&file, layer.wo_a,             tni(LLM_TENSOR_ATTN_OUT,  i, ".weight.loraA"));
+        write_tensor(&file, layer.wo_b,             tni(LLM_TENSOR_ATTN_OUT,  i, ".weight.loraB"));
+        write_tensor(&file, layer.ffn_norm_a,       tni(LLM_TENSOR_FFN_NORM,  i, ".weight.loraA"));
+        write_tensor(&file, layer.ffn_norm_b,       tni(LLM_TENSOR_FFN_NORM,  i, ".weight.loraB"));
+        write_tensor(&file, layer.w1_a,             tni(LLM_TENSOR_FFN_GATE,  i, ".weight.loraA"));
+        write_tensor(&file, layer.w1_b,             tni(LLM_TENSOR_FFN_GATE,  i, ".weight.loraB"));
+        write_tensor(&file, layer.w2_a,             tni(LLM_TENSOR_FFN_DOWN,  i, ".weight.loraA"));
+        write_tensor(&file, layer.w2_b,             tni(LLM_TENSOR_FFN_DOWN,  i, ".weight.loraB"));
+        write_tensor(&file, layer.w3_a,             tni(LLM_TENSOR_FFN_UP,    i, ".weight.loraA"));
+        write_tensor(&file, layer.w3_b,             tni(LLM_TENSOR_FFN_UP,    i, ".weight.loraB"));
+    }
+}
+
+struct train_params {
+    struct train_params_common common;
+
+    const char * fn_model_base;
+    const char * fn_lora_out;
+
+    bool only_write_lora;
+
+    float f_norm_rms_eps;
+    float rope_freq_base;
+    float rope_freq_scale;
+
+    bool custom_f_norm_rms_eps;
+    bool custom_rope_freq_base;
+    bool custom_rope_freq_scale;
+
+    int32_t lora_r;
+    int32_t lora_alpha;
+    bool custom_lora_alpha;
+
+    uint32_t n_rank_attention_norm;
+    uint32_t n_rank_wq;
+    uint32_t n_rank_wk;
+    uint32_t n_rank_wv;
+    uint32_t n_rank_wo;
+    uint32_t n_rank_ffn_norm;
+    uint32_t n_rank_w1;
+    uint32_t n_rank_w2;
+    uint32_t n_rank_w3;
+    uint32_t n_rank_tok_embeddings;
+    uint32_t n_rank_norm;
+    uint32_t n_rank_output;
+
+    bool custom_n_rank_attention_norm;
+    bool custom_n_rank_wq;
+    bool custom_n_rank_wk;
+    bool custom_n_rank_wv;
+    bool custom_n_rank_wo;
+    bool custom_n_rank_ffn_norm;
+    bool custom_n_rank_w1;
+    bool custom_n_rank_w2;
+    bool custom_n_rank_w3;
+    bool custom_n_rank_tok_embeddings;
+    bool custom_n_rank_norm;
+    bool custom_n_rank_output;
+};
+
+static struct train_params get_default_train_params() {
+    struct train_params params;
+    params.common = get_default_train_params_common();
+    params.fn_model_base     = "";
+    params.fn_lora_out       = "ggml-lora-ITERATION-f32.gguf";
+
+    params.only_write_lora = false;
+
+    params.f_norm_rms_eps  = 1e-5f;
+    params.rope_freq_base  = 10000.0f;
+    params.rope_freq_scale = 1.0f;
+
+    params.custom_f_norm_rms_eps  = false;
+    params.custom_rope_freq_base  = false;
+    params.custom_rope_freq_scale = false;
+
+    params.lora_r      = 4;
+    params.lora_alpha  = 4;
+    params.custom_lora_alpha = false;
+
+    params.n_rank_attention_norm = 1;
+    params.n_rank_wq             = 4;
+    params.n_rank_wk             = 4;
+    params.n_rank_wv             = 4;
+    params.n_rank_wo             = 4;
+    params.n_rank_ffn_norm       = 1;
+    params.n_rank_w1             = 4;
+    params.n_rank_w2             = 4;
+    params.n_rank_w3             = 4;
+    params.n_rank_tok_embeddings = 4;
+    params.n_rank_norm           = 1;
+    params.n_rank_output         = 4;
+
+    params.custom_n_rank_attention_norm = false;
+    params.custom_n_rank_wq             = false;
+    params.custom_n_rank_wk             = false;
+    params.custom_n_rank_wv             = false;
+    params.custom_n_rank_wo             = false;
+    params.custom_n_rank_ffn_norm       = false;
+    params.custom_n_rank_w1             = false;
+    params.custom_n_rank_w2             = false;
+    params.custom_n_rank_w3             = false;
+    params.custom_n_rank_tok_embeddings = false;
+    params.custom_n_rank_norm           = false;
+    params.custom_n_rank_output         = false;
+
+    return params;
+}
+
+static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help                 show this help message and exit\n");
+
+    fprintf(stderr, "  --model-base FNAME         model path from which to load base model (default '%s')\n", params->fn_model_base);
+    fprintf(stderr, "  --lora-out FNAME           path to save llama lora (default '%s')\n", params->fn_lora_out);
+    fprintf(stderr, "  --only-write-lora          only save llama lora, don't do any training.  use this if you only want to convert a checkpoint to a lora adapter.\n");
+    fprintf(stderr, "  --norm-rms-eps F           RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
+    fprintf(stderr, "  --rope-freq-base F         Frequency base for ROPE (default %f)\n", params->rope_freq_base);
+    fprintf(stderr, "  --rope-freq-scale F        Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
+    fprintf(stderr, "  --lora-alpha N             LORA alpha : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_alpha);
+    fprintf(stderr, "  --lora-r N                 LORA r: default rank. Also specifies resulting scaling together with lora-alpha. (default %d)\n", params->lora_r);
+    fprintf(stderr, "  --rank-att-norm N          LORA rank for attention norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+    fprintf(stderr, "  --rank-ffn-norm N          LORA rank for feed-forward norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+    fprintf(stderr, "  --rank-out-norm N          LORA rank for output norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+    fprintf(stderr, "  --rank-tok-embd N          LORA rank for token embeddings tensor, overrides default rank.\n");
+    fprintf(stderr, "  --rank-out N               LORA rank for output tensor, overrides default rank.\n");
+    fprintf(stderr, "  --rank-wq N                LORA rank for wq tensor, overrides default rank.\n");
+    fprintf(stderr, "  --rank-wk N                LORA rank for wk tensor, overrides default rank.\n");
+    fprintf(stderr, "  --rank-wv N                LORA rank for wv tensor, overrides default rank.\n");
+    fprintf(stderr, "  --rank-wo N                LORA rank for wo tensor, overrides default rank.\n");
+    fprintf(stderr, "  --rank-w1 N                LORA rank for w1 tensor, overrides default rank.\n");
+    fprintf(stderr, "  --rank-w2 N                LORA rank for w2 tensor, overrides default rank.\n");
+    fprintf(stderr, "  --rank-w3 N                LORA rank for w3 tensor, overrides default rank.\n");
+
+    print_common_train_usage(argc, argv, &params->common);
+}
+
+static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
+    bool invalid_param = false;
+    std::string arg;
+    struct train_params default_params = get_default_train_params();
+    const std::string arg_prefix = "--";
+
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+            std::replace(arg.begin(), arg.end(), '_', '-');
+        }
+
+        if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
+            if (invalid_param) {
+                break;
+            } else if (params->common.print_usage) {
+                train_print_usage(argc, argv, &default_params);
+                exit(0);
+            }
+        } else if (arg == "--model-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_model_base = argv[i];
+        } else if (arg == "--lora-out") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_lora_out = argv[i];
+        } else if (arg == "--only-write-lora") {
+            params->only_write_lora = true;
+        } else if (arg == "--norm-rms-eps") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->f_norm_rms_eps = std::stof(argv[i]);
+            params->custom_f_norm_rms_eps = true;
+        } else if (arg == "--rope-freq-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->rope_freq_base = std::stof(argv[i]);
+            params->custom_rope_freq_base = true;
+        } else if (arg == "--rope-freq-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->rope_freq_scale = std::stof(argv[i]);
+            params->custom_rope_freq_scale = true;
+        } else if (arg == "--lora-alpha") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->lora_alpha = std::stoi(argv[i]);
+            params->custom_lora_alpha = true;
+        } else if (arg == "--lora-r") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->lora_r = std::stoi(argv[i]);
+        } else if (arg == "--rank-att-norm") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_attention_norm = std::stoi(argv[i]);
+            params->custom_n_rank_attention_norm = true;
+        } else if (arg == "--rank-ffn-norm") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_ffn_norm = std::stoi(argv[i]);
+            params->custom_n_rank_ffn_norm = true;
+        } else if (arg == "--rank-out-norm") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_norm = std::stoi(argv[i]);
+            params->custom_n_rank_norm = true;
+        } else if (arg == "--rank-tok-embd") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_tok_embeddings = std::stoi(argv[i]);
+            params->custom_n_rank_tok_embeddings = true;
+        } else if (arg == "--rank-out") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_output = std::stoi(argv[i]);
+            params->custom_n_rank_output = true;
+        } else if (arg == "--rank-wq") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_wq = std::stoi(argv[i]);
+            params->custom_n_rank_wq = true;
+        } else if (arg == "--rank-wk") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_wk = std::stoi(argv[i]);
+            params->custom_n_rank_wk = true;
+        } else if (arg == "--rank-wv") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_wv = std::stoi(argv[i]);
+            params->custom_n_rank_wv = true;
+        } else if (arg == "--rank-wo") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_wo = std::stoi(argv[i]);
+            params->custom_n_rank_wo = true;
+        } else if (arg == "--rank-w1") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_w1 = std::stoi(argv[i]);
+            params->custom_n_rank_w1 = true;
+        } else if (arg == "--rank-w2") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_w2 = std::stoi(argv[i]);
+            params->custom_n_rank_w2 = true;
+        } else if (arg == "--rank-w3") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rank_w3 = std::stoi(argv[i]);
+            params->custom_n_rank_w3 = true;
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            train_print_usage(argc, argv, &default_params);
+            exit(1);
+        }
+    }
+    if (invalid_param) {
+        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+        train_print_usage(argc, argv, &default_params);
+        exit(1);
+    }
+    finish_processing_train_args(&params->common);
+    return true;
+}
+
+struct save_train_files_data {
+    const char            * fn_checkpoint_out;
+    const char            * fn_lora_out;
+    const char            * pattern_fn_it;
+    const char            * fn_latest;
+    struct my_llama_model * model;
+    struct my_llama_lora  * lora;
+};
+
+static void save_train_files(void * vdata, struct train_state * train) {
+    struct save_train_files_data * data   = (struct save_train_files_data *) vdata;
+
+    int64_t iter = train->opt->iter;
+
+    if (strlen(data->fn_checkpoint_out) > 0) {
+        save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->model, data->lora, train);
+        save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1  ).c_str(), data->model, data->lora, train);
+    }
+    if (strlen(data->fn_lora_out) > 0) {
+        save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->lora);
+        save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, -1  ).c_str(), data->lora);
+    }
+}
+
+static int64_t get_parameter_count(struct my_llama_lora* lora) {
+    int64_t nx = 0;
+    nx += ggml_nelements(lora->tok_embeddings_a);
+    nx += ggml_nelements(lora->tok_embeddings_b);
+    nx += ggml_nelements(lora->norm_a);
+    nx += ggml_nelements(lora->norm_b);
+    nx += ggml_nelements(lora->output_a);
+    nx += ggml_nelements(lora->output_b);
+
+    for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+        auto & layer = lora->layers[i];
+        nx += ggml_nelements(layer.attention_norm_a);
+        nx += ggml_nelements(layer.attention_norm_b);
+        nx += ggml_nelements(layer.wq_a);
+        nx += ggml_nelements(layer.wq_b);
+        nx += ggml_nelements(layer.wk_a);
+        nx += ggml_nelements(layer.wk_b);
+        nx += ggml_nelements(layer.wv_a);
+        nx += ggml_nelements(layer.wv_b);
+        nx += ggml_nelements(layer.wo_a);
+        nx += ggml_nelements(layer.wo_b);
+        nx += ggml_nelements(layer.ffn_norm_a);
+        nx += ggml_nelements(layer.ffn_norm_b);
+        nx += ggml_nelements(layer.w1_a);
+        nx += ggml_nelements(layer.w1_b);
+        nx += ggml_nelements(layer.w2_a);
+        nx += ggml_nelements(layer.w2_b);
+        nx += ggml_nelements(layer.w3_a);
+        nx += ggml_nelements(layer.w3_b);
+    }
+    return nx;
+}
+
+int main(int argc, char ** argv) {
+    struct train_params params = get_default_train_params();
+
+    if (!train_params_parse(argc, argv, &params)) {
+        return 1;
+    }
+
+    if (params.common.seed == LLAMA_DEFAULT_SEED) {
+        params.common.seed = time(NULL);
+    }
+    printf("%s: seed: %u\n", __func__, params.common.seed);
+    srand(params.common.seed);
+
+    struct llama_context_params llama_params = llama_context_default_params();
+    llama_params.vocab_only = false;
+
+    printf("%s: model base = '%s'\n", __func__, params.fn_model_base);
+    struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_params);
+    struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
+
+    struct my_llama_model model;
+    init_model(lmodel, &model, params.fn_model_base, params.common.n_ctx);
+
+    struct my_llama_lora lora;
+
+    struct train_state      * train = init_train_state();
+    struct ggml_opt_context * opt   = train->opt;
+
+    // set params from command line
+    if (params.custom_f_norm_rms_eps) {
+        model.hparams.f_norm_rms_eps  = params.f_norm_rms_eps;
+    }
+    if (params.custom_rope_freq_base) {
+        model.hparams.rope_freq_base  = params.rope_freq_base;
+    }
+    if (params.custom_rope_freq_scale) {
+        model.hparams.rope_freq_scale = params.rope_freq_scale;
+    }
+    lora.hparams.lora_r                = params.lora_r;
+    lora.hparams.lora_alpha            = params.custom_lora_alpha            ? params.lora_alpha            : params.lora_r;
+    uint32_t n_rank_attention_norm     = params.custom_n_rank_attention_norm ? params.n_rank_attention_norm : 1;
+    uint32_t n_rank_wq                 = params.custom_n_rank_wq             ? params.n_rank_wq             : params.lora_r;
+    uint32_t n_rank_wk                 = params.custom_n_rank_wk             ? params.n_rank_wk             : params.lora_r;
+    uint32_t n_rank_wv                 = params.custom_n_rank_wv             ? params.n_rank_wv             : params.lora_r;
+    uint32_t n_rank_wo                 = params.custom_n_rank_wo             ? params.n_rank_wo             : params.lora_r;
+    uint32_t n_rank_ffn_norm           = params.custom_n_rank_ffn_norm       ? params.n_rank_ffn_norm       : 1;
+    uint32_t n_rank_w1                 = params.custom_n_rank_w1             ? params.n_rank_w1             : params.lora_r;
+    uint32_t n_rank_w2                 = params.custom_n_rank_w2             ? params.n_rank_w2             : params.lora_r;
+    uint32_t n_rank_w3                 = params.custom_n_rank_w3             ? params.n_rank_w3             : params.lora_r;
+    uint32_t n_rank_tok_embeddings     = params.custom_n_rank_tok_embeddings ? params.n_rank_tok_embeddings : params.lora_r;
+    uint32_t n_rank_norm               = params.custom_n_rank_norm           ? params.n_rank_norm           : 1;
+    uint32_t n_rank_output             = params.custom_n_rank_output         ? params.n_rank_output         : params.lora_r;
+    lora.hparams.n_rank_attention_norm = n_rank_attention_norm;
+    lora.hparams.n_rank_wq             = n_rank_wq;
+    lora.hparams.n_rank_wk             = n_rank_wk;
+    lora.hparams.n_rank_wv             = n_rank_wv;
+    lora.hparams.n_rank_wo             = n_rank_wo;
+    lora.hparams.n_rank_ffn_norm       = n_rank_ffn_norm;
+    lora.hparams.n_rank_w1             = n_rank_w1;
+    lora.hparams.n_rank_w2             = n_rank_w2;
+    lora.hparams.n_rank_w3             = n_rank_w3;
+    lora.hparams.n_rank_tok_embeddings = n_rank_tok_embeddings;
+    lora.hparams.n_rank_norm           = n_rank_norm;
+    lora.hparams.n_rank_output         = n_rank_output;
+
+    // set opt params from command line
+    opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    opt->params.print_forward_graph     = false;
+    opt->params.print_backward_graph    = false;
+    opt->params.n_threads               = params.common.n_threads;
+    opt->params.past                    = params.common.opt_past;
+    opt->params.delta                   = params.common.opt_delta;
+    opt->params.max_no_improvement      = params.common.opt_max_no_improvement;
+    opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation;
+    opt->params.adam.n_iter             = params.common.adam_n_iter;
+    opt->params.adam.sched              = 1.0f;
+    opt->params.adam.alpha              = params.common.adam_alpha;
+    opt->params.adam.decay              = params.common.adam_decay;
+    opt->params.adam.decay_min_ndim     = params.common.adam_decay_min_ndim;
+    opt->params.adam.beta1              = params.common.adam_beta1;
+    opt->params.adam.beta2              = params.common.adam_beta2;
+    opt->params.adam.gclip              = params.common.adam_gclip;
+    opt->params.adam.eps_f              = params.common.adam_eps_f;
+
+    ggml_allocr * alloc = NULL;
+
+    printf("%s: init model\n", __func__);
+    bool existed = load_checkpoint_lora_file(params.common.fn_checkpoint_in, &model, &lora, train);
+
+    if (existed) {
+        // overwrite last n_ctx with user provided n_ctx
+        if (params.common.custom_n_ctx) {
+            model.hparams.n_ctx = params.common.n_ctx;
+        }
+
+        const bool opt_param_count_changed = (
+           (lora.hparams.n_rank_attention_norm != n_rank_attention_norm)
+        || (lora.hparams.n_rank_wq             != n_rank_wq)
+        || (lora.hparams.n_rank_wk             != n_rank_wk)
+        || (lora.hparams.n_rank_wv             != n_rank_wv)
+        || (lora.hparams.n_rank_wo             != n_rank_wo)
+        || (lora.hparams.n_rank_ffn_norm       != n_rank_ffn_norm)
+        || (lora.hparams.n_rank_w1             != n_rank_w1)
+        || (lora.hparams.n_rank_w2             != n_rank_w2)
+        || (lora.hparams.n_rank_w3             != n_rank_w3)
+        || (lora.hparams.n_rank_tok_embeddings != n_rank_tok_embeddings)
+        || (lora.hparams.n_rank_norm           != n_rank_norm)
+        || (lora.hparams.n_rank_output         != n_rank_output)
+        );
+
+        const bool opt_past_changed = opt->params.past != params.common.opt_past;
+
+        if (opt_param_count_changed) {
+            print_lora_params(&lora.hparams);
+            die("Provided rank differs from checkpoint file. To use different rank start finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting.");
+            // need to discard previous optimizer gradient statistics and opt_init with new shapes
+            // TODO
+        }
+        if (opt_past_changed) {
+            die("Optimizer parameter '--opt-past N' differs from checkpoint file. To use different value finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting");
+            // need to discard previous optimizer past function value statistics and opt_init with new shapes
+            // TODO
+        }
+    } else { // existed == false
+        init_lora(&model, &lora);
+        randomize_lora(&lora, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
+        if (!params.only_write_lora) {
+            ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&lora));
+        }
+    }
+    opt->iter = train->train_its;
+
+    print_params(&model.hparams);
+    print_lora_params(&lora.hparams);
+    printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) train->train_its);
+    printf("%s: seen train_samples     %llu\n", __func__, (long long unsigned) train->train_samples);
+    printf("%s: seen train_tokens      %llu\n", __func__, (long long unsigned) train->train_tokens);
+    printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
+    printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + lora.data.size()), (float) (ggml_used_mem(lora.ctx) + lora.data.size()) / (1024.0f*1024.0f));
+
+    if (params.only_write_lora) {
+        save_train_files_data save_data;
+        save_data.fn_checkpoint_out = "";
+        save_data.fn_lora_out       = params.fn_lora_out;
+        save_data.pattern_fn_it     = params.common.pattern_fn_it;
+        save_data.fn_latest         = params.common.fn_latest;
+        save_data.model             = &model;
+        save_data.lora              = &lora;
+
+        save_train_files(&save_data, train);
+
+        free_train_state(train);
+        ggml_free(lora.ctx);
+        llama_free(lctx);
+        llama_free_model(lmodel);
+        return 0;
+    }
+
+    printf("%s: opt_size  = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
+    printf("%s: opt iter %d\n", __func__, opt->iter);
+
+    int n_tokens = model.hparams.n_ctx;
+    int n_vocab  = model.hparams.n_vocab;
+    int n_batch  = params.common.n_batch;
+
+
+    std::vector<uint8_t> mem_input_data;
+    std::vector<uint8_t> mem_compute_data;
+
+    // context for input tensors without their data
+    struct ggml_init_params ctx_input_params = {
+        ggml_tensor_overhead() * 2, // mem_size
+        NULL,                       // mem_buffer
+        true,                       // no_alloc
+    };
+    struct ggml_context * ctx_input = ggml_init(ctx_input_params);
+
+    // the input tensors
+    struct ggml_tensor * tokens_input  = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch);
+    struct ggml_tensor * target_probs  = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab,  n_tokens, n_batch);
+
+    // measure required memory for input tensors
+    alloc = ggml_allocr_new_measure(tensor_alignment);
+    ggml_allocr_alloc(alloc, tokens_input);
+    ggml_allocr_alloc(alloc, target_probs);
+    size_t max_input_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+    ggml_allocr_free(alloc);
+    printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
+
+    // allocate input tensors
+    mem_input_data.resize(max_input_size);
+    alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
+    ggml_allocr_alloc(alloc, tokens_input);
+    ggml_allocr_alloc(alloc, target_probs);
+    ggml_allocr_free(alloc);
+
+    // context for compute tensors without their data
+    size_t estimated_compute_size_wo_data = (
+        ggml_tensor_overhead()*GGML_MAX_NODES*2
+      + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
+            params.common.use_checkpointing ? 3 : 2
+        )
+    );
+    struct ggml_init_params ctx_compute_params = {
+        estimated_compute_size_wo_data, // mem_size
+        NULL,                           // mem_buffer
+        true,                           // no_alloc
+    };
+    struct ggml_context * ctx_compute = NULL;
+
+    struct ggml_tensor * loss   = NULL;
+    struct ggml_tensor * logits = NULL;
+
+    struct ggml_cgraph * gf     = NULL;
+    struct ggml_cgraph * gb     = NULL;
+    struct ggml_cgraph * gb_tmp = NULL;
+
+    // measure required memory for compute tensors
+    size_t best_compute_size = SIZE_MAX;
+    enum ggml_cgraph_eval_order best_order = GGML_CGRAPH_EVAL_ORDER_COUNT;
+    // find best evaluation order
+    for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
+        ctx_compute = ggml_init(ctx_compute_params);
+        alloc = ggml_allocr_new_measure(tensor_alignment);
+        gf = ggml_new_graph(ctx_compute);
+        gf->order = (enum ggml_cgraph_eval_order) order;
+        gb = ggml_new_graph(ctx_compute);
+        gb_tmp = params.common.use_checkpointing
+            ? ggml_new_graph(ctx_compute)
+            : NULL;
+        loss = llama_build_lora_finetune_graphs(
+            &model, &lora, alloc, ctx_compute,
+            gf, gb, gb_tmp,
+            &logits, tokens_input, target_probs,
+            n_tokens, n_batch,
+            params.common.use_flash,
+            params.common.use_checkpointing
+        );
+        size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+        if (max_compute_size < best_compute_size) {
+            best_compute_size = max_compute_size;
+            best_order = gf->order;
+        }
+        ggml_allocr_free(alloc);
+        ggml_free(ctx_compute);
+    }
+    size_t max_compute_size = best_compute_size;
+    printf("%s: compute_size = %zu bytes (%.1f MB)\n", __func__, max_compute_size, (float) max_compute_size / (1024.0f*1024.0f));
+    printf("%s: evaluation order = %s\n", __func__,
+        (best_order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? "LEFT_TO_RIGHT" :
+        (best_order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? "RIGHT_TO_LEFT" :
+        "invalid");
+
+    // allocate compute tensors
+    mem_compute_data.resize(max_compute_size);
+    ctx_compute = ggml_init(ctx_compute_params);
+    alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
+    gf = ggml_new_graph(ctx_compute);
+    gf->order = best_order;
+    gb = ggml_new_graph(ctx_compute);
+    gb_tmp = params.common.use_checkpointing
+        ? ggml_new_graph(ctx_compute)
+        : NULL;
+    loss = llama_build_lora_finetune_graphs(
+        &model, &lora, alloc, ctx_compute,
+        gf, gb, gb_tmp,
+        &logits, tokens_input, target_probs,
+        n_tokens, n_batch,
+        params.common.use_flash,
+        params.common.use_checkpointing
+    );
+    ggml_allocr_free(alloc);
+
+    // tokenize data
+    std::vector<llama_token> train_tokens;
+    std::vector<size_t> train_samples_begin;
+    std::vector<size_t> train_samples_size;
+    printf("%s: tokenize training data\n", __func__);
+    tokenize_file(lctx,
+            params.common.fn_train_data,
+            params.common.sample_start,
+            params.common.include_sample_start,
+            params.common.overlapping_samples,
+            n_tokens,
+            train_tokens,
+            train_samples_begin,
+            train_samples_size);
+    GGML_ASSERT(train_samples_begin.size() == train_samples_size.size());
+
+    printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size());
+
+    std::vector<size_t> token_noccurs;
+    token_noccurs.resize(model.hparams.n_vocab, 0);
+    for (unsigned int i = 0; i < train_tokens.size(); ++i) {
+        ++token_noccurs[train_tokens[i]];
+    }
+    int n_unique_tokens = 0;
+    for (unsigned int i = 0; i < token_noccurs.size(); ++i) {
+        if (token_noccurs[i] == 0) continue;
+        ++n_unique_tokens;
+    }
+    printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
+
+    size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
+    const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
+    if (changed_train_data) {
+        printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
+    }
+    if (params.common.force_reshuffle) {
+        printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
+    }
+    if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
+        train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
+        train->shuffle_sample_count = train_samples_size.size();
+        train->shuffle_next_sample = 0;
+        train->shuffle_samples_hash = shuffle_samples_hash;
+    }
+    std::vector<size_t> train_shuffled_samples_offs;
+    std::vector<size_t> train_shuffled_samples_begin;
+    std::vector<size_t> train_shuffled_samples_size;
+    train_shuffled_samples_offs.resize(train_samples_begin.size());
+    train_shuffled_samples_begin.resize(train_samples_begin.size());
+    train_shuffled_samples_size.resize(train_samples_size.size());
+    train->shuffle_rng_state_next = shuffle_samples(
+        train->shuffle_rng_state_current,
+        train_shuffled_samples_offs.data(),
+        train_shuffled_samples_begin.data(),
+        train_shuffled_samples_size.data(),
+        train_samples_begin.data(),
+        train_samples_size.data(),
+        train_samples_size.size());
+
+    printf("%s: begin training\n", __func__);
+
+    save_train_files_data save_data;
+    save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
+    save_data.fn_lora_out       = params.fn_lora_out;
+    save_data.pattern_fn_it     = params.common.pattern_fn_it;
+    save_data.fn_latest         = params.common.fn_latest;
+    save_data.model             = &model;
+    save_data.lora              = &lora;
+
+    struct train_opt_callback_data opt_cb_data;
+    opt_cb_data.params                 = &params.common;
+    opt_cb_data.train                  = train;
+    opt_cb_data.save_cb                = &save_train_files;
+    opt_cb_data.save_data              = &save_data;
+    opt_cb_data.lctx                   = lctx;
+    opt_cb_data.last_save_iter         = opt->iter;
+    opt_cb_data.tokens_data            = train_tokens.data();
+    opt_cb_data.tokens_size            = train_tokens.size();
+    opt_cb_data.samples_begin          = train_samples_begin.data();
+    opt_cb_data.samples_size           = train_samples_size.data();
+    opt_cb_data.shuffled_samples_offs  = train_shuffled_samples_offs.data();
+    opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
+    opt_cb_data.shuffled_samples_size  = train_shuffled_samples_size.data();
+    opt_cb_data.samples_count          = train_samples_size.size();
+    opt_cb_data.tokens_input           = tokens_input;
+    opt_cb_data.target_probs           = target_probs;
+    opt_cb_data.first_iter             = opt->iter;
+    opt_cb_data.first_epoch            = train->train_epochs;
+    opt_cb_data.iter_at_last_epoch     = -1;
+    opt_cb_data.last_time              = ggml_time_ms();
+    opt_cb_data.millis_per_iter        = 0.0;
+
+    // measure required memory for work buffer
+    size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
+    printf("%s: work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));
+
+    // context for work buffer
+    struct ggml_init_params ctx_work_params = {
+        max_work_size, // mem_size
+        NULL,          // mem_buffer
+        false,         // no_alloc
+    };
+    struct ggml_context * ctx_work = ggml_init(ctx_work_params);
+
+    int64_t t0 = ggml_time_ms();
+
+    ggml_opt_resume_g(ctx_work, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
+
+    ggml_free(ctx_work);
+    ggml_free(ctx_compute);
+    ggml_free(ctx_input);
+
+    int64_t t1 = ggml_time_ms();
+    printf("%s: total training time: ", __func__);
+    print_duration((double) (t1 - t0));
+    printf("\n");
+
+    int new_iters = opt->iter - opt_cb_data.last_save_iter;
+    if (new_iters > 0) {
+        train->train_its     += new_iters;
+        train->train_tokens  += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
+
+        save_train_files(&save_data, train);
+        opt_cb_data.last_save_iter = opt->iter;
+    }
+
+    ggml_free(opt->ctx);
+    free_train_state(train);
+    ggml_free(lora.ctx);
+    llama_free(lctx);
+    llama_free_model(lmodel);
+    return 0;
+}

+ 17 - 1
examples/server/server.cpp

@@ -956,7 +956,23 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 invalid_param = true;
                 break;
             }
-            params.lora_adapter = argv[i];
+            params.lora_adapter.push_back({argv[i], 1.0f});
+            params.use_mmap = false;
+        }
+        else if (arg == "--lora-scaled")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            const char * lora_adapter = argv[i];
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])});
             params.use_mmap = false;
         }
         else if (arg == "--lora-base")

+ 8 - 3
examples/train-text-from-scratch/README.md

@@ -10,9 +10,9 @@ wget https://raw.githubusercontent.com/brunoklein99/deep-learning-notes/master/s
 ./bin/train-text-from-scratch \
         --vocab-model ../models/ggml-vocab-llama.gguf \
         --ctx 64 --embd 256 --head 8 --layer 16 \
-        --checkpoint-in  chk-shakespeare-256x16.gguf \
-        --checkpoint-out chk-shakespeare-256x16.gguf \
-        --model-out ggml-shakespeare-256x16-f32.gguf \
+        --checkpoint-in  chk-shakespeare-256x16-LATEST.gguf \
+        --checkpoint-out chk-shakespeare-256x16-ITERATION.gguf \
+        --model-out ggml-shakespeare-256x16-f32-ITERATION.gguf \
         --train-data "shakespeare.txt" \
         -t 6 -b 16 --seed 1 --adam-iter 256 \
         --no-checkpointing
@@ -20,3 +20,8 @@ wget https://raw.githubusercontent.com/brunoklein99/deep-learning-notes/master/s
 # predict
 ./bin/main -m ggml-shakespeare-256x16-f32.gguf
 ```
+
+Output files will be saved every N iterations (config with `--save-every N`).
+The pattern "ITERATION" in the output filenames will be replaced with the iteration number and "LATEST" for the latest output.
+
+To train GGUF models just pass them to `--checkpoint-in FN`.

+ 8 - 4
examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py

@@ -47,10 +47,13 @@ LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS           = "optimizer.lbfgs.memory_ys"
 LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S            = "optimizer.lbfgs.memory_s"
 LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y            = "optimizer.lbfgs.memory_y"
 
-LLM_KV_TRAINING_FILE_VERSION    = "training.file_version"
-LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"
-LLM_KV_TRAINING_SAMPLE_COUNT    = "training.sample_count"
-LLM_KV_TRAINING_TOKEN_COUNT     = "training.token_count"
+LLM_KV_TRAINING_TYPE_TRAIN_MODEL   = "train_model"
+LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"
+LLM_KV_TRAINING_TYPE               = "training.type"
+LLM_KV_TRAINING_FILE_VERSION       = "training.file_version"
+LLM_KV_TRAINING_ITERATION_COUNT    = "training.iteration_count"
+LLM_KV_TRAINING_SAMPLE_COUNT       = "training.sample_count"
+LLM_KV_TRAINING_TOKEN_COUNT        = "training.token_count"
 
 class Tensor:
     def __init__(self, dtype='f', ne=None):
@@ -460,6 +463,7 @@ class Checkpoint:
         gguf_writer.add_file_type(gguf.GGMLQuantizationType.F32)
         gguf_writer.add_layer_norm_rms_eps(1e-5)
         gguf_writer.add_uint32(LLM_KV_TRAINING_FILE_VERSION,    0)
+        gguf_writer.add_string(LLM_KV_TRAINING_TYPE,            LLM_KV_TRAINING_TYPE_TRAIN_MODEL)
         gguf_writer.add_uint32(LLM_KV_TRAINING_ITERATION_COUNT, self.train_its)
         gguf_writer.add_uint32(LLM_KV_TRAINING_SAMPLE_COUNT,    self.train_samples)
         gguf_writer.add_uint32(LLM_KV_TRAINING_TOKEN_COUNT,     self.train_tokens)

File diff ditekan karena terlalu besar
+ 148 - 842
examples/train-text-from-scratch/train-text-from-scratch.cpp


+ 8 - 2
ggml-alloc.c

@@ -77,7 +77,7 @@ struct free_block {
     size_t size;
 };
 
-#define MAX_FREE_BLOCKS 128
+#define MAX_FREE_BLOCKS 256
 
 struct ggml_allocr {
     void * data;
@@ -187,6 +187,7 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
     }
 
     tensor->data = addr;
+    AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data);
 
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
@@ -218,7 +219,8 @@ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tens
 
     size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
     size = aligned_offset(NULL, size, alloc->alignment);
-    AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
+    AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
+    AT_PRINTF("%s: alloc->data = %p alloc->data+alloc->size = %p alloc->data+alloc->max_size = %p\n", __func__, alloc->data, (char*)alloc->data + alloc->size, (char*)alloc->data + alloc->max_size);
 
 #ifdef GGML_ALLOCATOR_DEBUG
     remove_allocated_tensor(alloc, tensor);
@@ -631,3 +633,7 @@ static size_t ggml_allocr_alloc_graph_tensors_n(
 size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
     return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
 }
+
+size_t ggml_allocr_max_size(struct ggml_allocr * alloc) {
+    return alloc->max_size;
+}

+ 1 - 0
ggml-alloc.h

@@ -19,6 +19,7 @@ GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);
 GGML_API void   ggml_allocr_reset(struct ggml_allocr * alloc);
 GGML_API void   ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
 GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
+GGML_API size_t ggml_allocr_max_size(struct ggml_allocr * alloc);
 
 
 #ifdef  __cplusplus

File diff ditekan karena terlalu besar
+ 600 - 90
ggml.c


+ 43 - 4
ggml.h

@@ -214,8 +214,8 @@
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
 #define GGML_MAX_DIMS          4
-#define GGML_MAX_NODES         4096
-#define GGML_MAX_PARAMS        256
+#define GGML_MAX_NODES         16384
+#define GGML_MAX_PARAMS        1024
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_SRC           6
 #define GGML_MAX_NAME          64
@@ -526,7 +526,15 @@ extern "C" {
     // next prime after GGML_MAX_NODES
     // #define GGML_GRAPH_HASHTABLE_SIZE 4099
     // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
-    #define GGML_GRAPH_HASHTABLE_SIZE 8273
+    // #define GGML_GRAPH_HASHTABLE_SIZE 8273
+    // #define GGML_GRAPH_HASHTABLE_SIZE 16411
+    #define GGML_GRAPH_HASHTABLE_SIZE 32771
+
+    enum ggml_cgraph_eval_order {
+        GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
+        GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
+        GGML_CGRAPH_EVAL_ORDER_COUNT
+    };
 
     // computation graph
     struct ggml_cgraph {
@@ -539,6 +547,8 @@ extern "C" {
 
         void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
 
+        enum ggml_cgraph_eval_order order;
+
         // performance
         int     perf_runs;
         int64_t perf_cycles;
@@ -686,12 +696,21 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
     GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
 
+    // Converts a flat index into coordinates
+    GGML_API void    ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
+
     GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
     GGML_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
 
+    GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_API void    ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+
     GGML_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
     GGML_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
 
+    GGML_API float   ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_API void    ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
@@ -725,6 +744,12 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_add_cast(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            enum   ggml_type      type);
+
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -834,6 +859,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // sums repetitions in a into shape of b
     GGML_API struct ggml_tensor * ggml_repeat_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1689,6 +1715,16 @@ extern "C" {
     // dump the graph into a file using the dot format
     GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
 
+    // build gradient checkpointing backward graph gb for gf using provided checkpoints
+    // gb_tmp will contain original backward graph with rewritten backward process nodes,
+    // but without the second forward pass nodes.
+    GGML_API void ggml_build_backward_gradient_checkpointing(
+            struct ggml_context   * ctx,
+            struct ggml_cgraph    * gf,
+            struct ggml_cgraph    * gb,
+            struct ggml_cgraph    * gb_tmp,
+            struct ggml_tensor  * * checkpoints,
+            int                     n_checkpoints);
     //
     // optimization
     //
@@ -1723,7 +1759,7 @@ extern "C" {
         GGML_LINESEARCH_INVALID_PARAMETERS,
     };
 
-    typedef void (*ggml_opt_callback)(void * data, float * sched);
+    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
     typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
 
     // optimization parameters
@@ -1755,6 +1791,8 @@ extern "C" {
         bool print_forward_graph;
         bool print_backward_graph;
 
+        int n_gradient_accumulation;
+
         // ADAM parameters
         struct {
             int n_iter;
@@ -1800,6 +1838,7 @@ extern "C" {
         float loss_after;
 
         struct {
+            struct ggml_tensor * g;  // current gradient
             struct ggml_tensor * m;  // first moment
             struct ggml_tensor * v;  // second moment
             struct ggml_tensor * pf; // past function values

+ 13 - 8
llama.cpp

@@ -6298,7 +6298,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
 // TODO: after the GGUF PR, this likely won't work and needs to be updated
 static int llama_apply_lora_from_file_internal(
-    const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads
+    const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
 ) {
     LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
 
@@ -6327,7 +6327,7 @@ static int llama_apply_lora_from_file_internal(
     int32_t lora_alpha;
     fin.read((char *) &lora_r, sizeof(lora_r));
     fin.read((char *) &lora_alpha, sizeof(lora_alpha));
-    float scaling = (float)lora_alpha / (float)lora_r;
+    float scaling = scale * (float)lora_alpha / (float)lora_r;
 
     LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
 
@@ -6543,9 +6543,10 @@ static int llama_apply_lora_from_file_internal(
                 ggml_set_name(r, "r_cpy");
             }
 
-            struct ggml_cgraph gf = ggml_build_forward(r);
+            struct ggml_cgraph * gf = ggml_new_graph(lora_ctx);
+            ggml_build_forward_expand(gf, r);
 
-            ggml_graph_compute_helper(work_buffer, &gf, n_threads);
+            ggml_graph_compute_helper(work_buffer, gf, n_threads);
 
             // we won't need these tensors again, reset the context to save memory
             ggml_free(lora_ctx);
@@ -6926,6 +6927,10 @@ uint64_t llama_model_n_params(const struct llama_model * model) {
     return nparams;
 }
 
+struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
+    return ggml_get_tensor(model->ctx, name);
+}
+
 int llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
@@ -6939,18 +6944,18 @@ int llama_model_quantize(
     }
 }
 
-int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
+int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, float scale, const char * path_base_model, int n_threads) {
     try {
-        return llama_apply_lora_from_file_internal(ctx->model, path_lora, path_base_model, n_threads);
+        return llama_apply_lora_from_file_internal(ctx->model, path_lora, scale, path_base_model, n_threads);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
         return 1;
     }
 }
 
-int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, const char * path_base_model, int n_threads) {
+int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int n_threads) {
     try {
-        return llama_apply_lora_from_file_internal(*model, path_lora, path_base_model, n_threads);
+        return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
         return 1;

+ 8 - 3
llama.h

@@ -287,6 +287,9 @@ extern "C" {
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
 
+    // Get a llama model tensor
+    LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
+
     // Returns 0 on success
     LLAMA_API int llama_model_quantize(
             const char * fname_inp,
@@ -302,15 +305,17 @@ extern "C" {
     LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
             struct llama_context * ctx,
                       const char * path_lora,
+                           float   scale,
                       const char * path_base_model,
                              int   n_threads),
             "use llama_model_apply_lora_from_file instead");
 
     LLAMA_API int llama_model_apply_lora_from_file(
             const struct llama_model * model,
-                          const char * path_lora,
-                          const char * path_base_model,
-                                 int   n_threads);
+                      const char * path_lora,
+                           float   scale,
+                      const char * path_base_model,
+                             int   n_threads);
 
     //
     // KV cache

+ 119 - 46
tests/test-grad0.cpp

@@ -251,18 +251,20 @@ static bool check_gradient(
         printf("GGML_N_THREADS = %d\n", n_threads);
     }
 
-    struct ggml_cgraph gf = ggml_build_forward (f);
-    struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+    struct ggml_cgraph * gf = ggml_build_forward_ctx(ctx0, f);
+    struct ggml_cgraph * gb = ggml_new_graph(ctx0);
+    *gb = *gf;
+    ggml_build_backward_expand(ctx0, gf, gb, false);
 
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
-    ggml_graph_reset  (&gf);
+    ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
 
-    ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+    ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
-    // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
-    // ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
+    // ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot");
+    // ggml_graph_dump_dot(gb, gf,  "test-grad0-backward.dot");
 
     for (int i = 0; i < nargs; ++i) {
         const int nelements = ggml_nelements(x[i]);
@@ -273,13 +275,13 @@ static bool check_gradient(
             const float xp = x0 + eps;
             ggml_set_f32_1d(x[i], k, xp);
 
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+            ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
             const double f0 = ggml_get_f32_1d(f, 0);
 
             ggml_set_f32_1d(x[i], k, xm);
 
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+            ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
             const double f1 = ggml_get_f32_1d(f, 0);
             const double g0 = (f0 - f1)/(2.0*(double) eps);
@@ -287,10 +289,10 @@ static bool check_gradient(
             ggml_set_f32_1d(x[i], k, x0);
 
             // compute gradient using backward graph
-            ggml_graph_reset  (&gf);
+            ggml_graph_reset  (gf);
             ggml_set_f32      (f->grad, 1.0f);
 
-            ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+            ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
             const double g1 = ggml_get_f32_1d(x[i]->grad, k);
 
@@ -373,7 +375,7 @@ static bool check_mat_mul(
 
 int main(int argc, const char ** argv) {
     struct ggml_init_params params = {
-        /* .mem_size   = */ 128*1024*1024,
+        /* .mem_size   = */ 256*1024*1024,
         /* .mem_buffer = */ NULL,
         /* .no_alloc   = */ false,
     };
@@ -405,6 +407,7 @@ int main(int argc, const char ** argv) {
         }
     }
 
+    unsigned seed_iter = 1;
 
     // original loop: 1000
     int niter = 4;
@@ -416,6 +419,10 @@ int main(int argc, const char ** argv) {
         niter = atoi(argv[1]);
     }
     for (int iter = 0; iter < niter; ++iter) {
+        srand(seed_iter);
+        seed_iter = rand();
+        unsigned seed = rand();
+
         printf("test-grad0: iter:%d/%d\n", iter, niter);
         struct ggml_context * ctx0 = ggml_init(params);
 
@@ -425,6 +432,7 @@ int main(int argc, const char ** argv) {
 
         // add f32
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -441,6 +449,7 @@ int main(int argc, const char ** argv) {
 
         // add f16
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -457,6 +466,7 @@ int main(int argc, const char ** argv) {
 
         // sub
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -473,6 +483,7 @@ int main(int argc, const char ** argv) {
 
         // mul
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -489,6 +500,7 @@ int main(int argc, const char ** argv) {
 
         // div
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -505,6 +517,7 @@ int main(int argc, const char ** argv) {
 
         // sqr
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -521,6 +534,7 @@ int main(int argc, const char ** argv) {
 
         // sqrt
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -537,6 +551,7 @@ int main(int argc, const char ** argv) {
 
         // log
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -553,6 +568,7 @@ int main(int argc, const char ** argv) {
 
         // sum
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -570,6 +586,7 @@ int main(int argc, const char ** argv) {
 
         // sum_rows
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -587,6 +604,7 @@ int main(int argc, const char ** argv) {
         // mean, not yet fully implemented
         if(0)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -604,6 +622,7 @@ int main(int argc, const char ** argv) {
         // argmax
         if (0)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -620,6 +639,7 @@ int main(int argc, const char ** argv) {
 
         // repeat
         {
+            srand(seed);
             int64_t ne2[4];
             get_random_dims(ne2, 4);
 
@@ -642,6 +662,7 @@ int main(int argc, const char ** argv) {
 
         // repeat back
         {
+            srand(seed);
             int64_t ne2[4];
             get_random_dims(ne2, 4);
 
@@ -680,6 +701,7 @@ int main(int argc, const char ** argv) {
 
         // sgn
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -696,6 +718,7 @@ int main(int argc, const char ** argv) {
 
         // neg
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -712,6 +735,7 @@ int main(int argc, const char ** argv) {
 
         // step
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -729,6 +753,7 @@ int main(int argc, const char ** argv) {
         // tanh, not yet fully implemented
         if(0)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -745,33 +770,45 @@ int main(int argc, const char ** argv) {
 
         // mul_mat
         {
+            srand(seed);
             const int nargs = 2;
 
-            for (int ndims = 2; ndims <= 2; ++ndims) {
+            for (int ndims = 2; ndims <= 4; ++ndims) {
+                int max_nrep = (ndims >= 3) ? 2 : 1;
                 x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                {
-                    int64_t ne2[4];
-                    get_random_dims(ne2, 4);
-                    ne2[0] = ne[0];
-                    x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                }
+                for (int nrep2 = 1; nrep2 < max_nrep; ++nrep2) {
+                    for (int nrep3 = 1; nrep3 < max_nrep; ++nrep3) {
+                        {
+                            int64_t ne2[4];
+                            get_random_dims(ne2, 4);
+                            ne2[0] = ne[0];
+                            ne2[2] = nrep2 * ne[2];
+                            ne2[3] = nrep3 * ne[3];
+                            x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+                        }
 
-                ggml_set_param(ctx0, x[0]);
-                ggml_set_param(ctx0, x[1]);
+                        ggml_set_param(ctx0, x[0]);
+                        ggml_set_param(ctx0, x[1]);
 
-                struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
-                struct ggml_tensor * f = ggml_sum(ctx0, m);
+                        struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
+                        struct ggml_tensor * f = ggml_sum(ctx0, m);
 
-                GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
+                        GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
 
-                check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-                check_mat_mul(m, x[1], x[0]);
+                        check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                        if (ndims == 2) {
+                            // check_mat_mul does not support ndims > 2
+                            check_mat_mul(m, x[1], x[0]);
+                        }
+                    }
+                }
             }
         }
 
         // elu, not yet fully implemented
         if(0)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -788,6 +825,7 @@ int main(int argc, const char ** argv) {
 
         // relu
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -805,6 +843,7 @@ int main(int argc, const char ** argv) {
         // gelu, not yet fully implemented
         if(0)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -821,6 +860,7 @@ int main(int argc, const char ** argv) {
 
         // silu
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -842,6 +882,7 @@ int main(int argc, const char ** argv) {
 
         // rms_norm
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -858,6 +899,7 @@ int main(int argc, const char ** argv) {
 
         // scale
         {
+            srand(seed);
             const int nargs = 2;
 
             int64_t ne2[4];
@@ -878,6 +920,7 @@ int main(int argc, const char ** argv) {
 
         // cpy f32
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -895,6 +938,7 @@ int main(int argc, const char ** argv) {
 
         // cpy f16
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -912,6 +956,7 @@ int main(int argc, const char ** argv) {
 
         // reshape (1d->nd)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -935,6 +980,7 @@ int main(int argc, const char ** argv) {
 
         // reshape (nd->1d)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -958,6 +1004,7 @@ int main(int argc, const char ** argv) {
 
         // acc 1d
         {
+            srand(seed);
             int64_t ne2[4] = { 1, 1, 1, 1 };
 
             const int nargs = 2;
@@ -985,6 +1032,7 @@ int main(int argc, const char ** argv) {
 
         // acc 2d
         {
+            srand(seed);
             int64_t ne2[4]         = { 1, 1, 1, 1 };
             int64_t max_offsets[4] = { 0, 0, 0, 0 };
             int64_t offsets[4]     = { 0, 0, 0, 0 };
@@ -1017,6 +1065,7 @@ int main(int argc, const char ** argv) {
 
         // acc 3d
         {
+            srand(seed);
             int64_t ne2[4]         = { 1, 1, 1, 1 };
             int64_t max_offsets[4] = { 0, 0, 0, 0 };
             int64_t offsets[4]     = { 0, 0, 0, 0 };
@@ -1051,6 +1100,7 @@ int main(int argc, const char ** argv) {
 
         // acc 4d
         {
+            srand(seed);
             int64_t ne2[4]         = { 1, 1, 1, 1 };
             int64_t max_offsets[4] = { 0, 0, 0, 0 };
             int64_t offsets[4]     = { 0, 0, 0, 0 };
@@ -1087,6 +1137,7 @@ int main(int argc, const char ** argv) {
 
         // set_1d
         {
+            srand(seed);
             int64_t ne2[4];
 
             const int nargs = 2;
@@ -1114,6 +1165,7 @@ int main(int argc, const char ** argv) {
 
         // set_2d
         {
+            srand(seed);
             int64_t ne2[4];
             int64_t max_offsets[4] = { 0, 0, 0, 0 };
             int64_t offsets[4]     = { 0, 0, 0, 0 };
@@ -1146,6 +1198,7 @@ int main(int argc, const char ** argv) {
 
         // view_1d
         {
+            srand(seed);
             const int nargs = 1;
             for (int ndims = 1; ndims <= 4; ++ndims) {
 
@@ -1169,6 +1222,7 @@ int main(int argc, const char ** argv) {
 
         // view_2d
         {
+            srand(seed);
             int64_t ne2[4];
             int64_t nb2[4];
 
@@ -1199,6 +1253,7 @@ int main(int argc, const char ** argv) {
 
         // view_3d
         {
+            srand(seed);
             int64_t ne2[4] = {1,1,1,1};
             int64_t nb2[4] = {0,0,0,0};
 
@@ -1230,6 +1285,7 @@ int main(int argc, const char ** argv) {
 
         // permute
         {
+            srand(seed);
             int64_t ne2[4];
 
             const int nargs = 1;
@@ -1263,6 +1319,7 @@ int main(int argc, const char ** argv) {
 
         // transpose
         {
+            srand(seed);
             int64_t ne2[4];
 
             const int nargs = 1;
@@ -1290,6 +1347,7 @@ int main(int argc, const char ** argv) {
 
         // get_rows
         {
+            srand(seed);
             int64_t ne2[4] = {ne[0], ne[1], 1, 1};
             int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
             const int nargs = 1;
@@ -1306,6 +1364,7 @@ int main(int argc, const char ** argv) {
 
         // diag_mask_inf
         {
+            srand(seed);
             const int nargs = 1;
             const int ndims = 2;
 
@@ -1321,6 +1380,7 @@ int main(int argc, const char ** argv) {
 
         // diag_mask_zero
         {
+            srand(seed);
             const int nargs = 1;
             const int ndims = 2;
 
@@ -1336,6 +1396,7 @@ int main(int argc, const char ** argv) {
 
         // softmax
         {
+            srand(seed);
             const int nargs = 1;
 
             int64_t ne2[4];
@@ -1357,11 +1418,16 @@ int main(int argc, const char ** argv) {
                                                     ggml_new_f32(ctx0, eps))));
 
                 check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
+                // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf.
+                // this may result in different gradients too finite differences.
+                // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause.
+                // if only the table lookup causes gradients to differ this is acceptable.
             }
         }
 
         // cross_entropy_loss
         {
+            srand(seed);
             const int nargs = 1;
 
             int64_t ne2[4];
@@ -1392,6 +1458,7 @@ int main(int argc, const char ** argv) {
 
         // rope f32
         {
+            srand(seed);
             const int nargs = 1;
 
             int64_t ne2[4];
@@ -1431,6 +1498,7 @@ int main(int argc, const char ** argv) {
 
         // rope f16
         {
+            srand(seed);
             const int nargs = 1;
 
             int64_t ne2[4];
@@ -1470,6 +1538,7 @@ int main(int argc, const char ** argv) {
 
         // flash_attn f32
         {
+            srand(seed);
             const int nargs = 3;
 
             int64_t ne2[4];
@@ -1482,28 +1551,31 @@ int main(int argc, const char ** argv) {
 
             for (int masked = 0; masked <= 1; ++masked) {
                 for (int ndims = 2; ndims <= 4; ++ndims) {
-                    int64_t neq[4] = { D, N, B, ne[3] };
-                    int64_t nek[4] = { D, M, B, ne[3] };
-                    int64_t nev[4] = { M, D, B, ne[3] };
-                    if (ndims == 2) {
-                        neq[2] = 1; neq[3] = 1;
-                        nek[2] = 1; nek[3] = 1;
-                        nev[2] = 1; nev[3] = 1;
-                    } else if (ndims == 3) {
-                        neq[3] = 1;
-                        nek[3] = 1;
-                        nev[3] = 1;
-                    }
-                    x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
-                    x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
-                    x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
-                    ggml_set_param(ctx0, x[0]);
-                    ggml_set_param(ctx0, x[1]);
-                    ggml_set_param(ctx0, x[2]);
+                    int max_nrep = (ndims >= 3) ? 2 : 1;
+                    for (int nrep = 1; nrep < max_nrep; ++nrep) {
+                        int64_t neq[4] = { D, N, B*nrep, ne[3] };
+                        int64_t nek[4] = { D, M, B, ne[3] };
+                        int64_t nev[4] = { M, D, B, ne[3] };
+                        if (ndims == 2) {
+                            neq[2] = 1; neq[3] = 1;
+                            nek[2] = 1; nek[3] = 1;
+                            nev[2] = 1; nev[3] = 1;
+                        } else if (ndims == 3) {
+                            neq[3] = 1;
+                            nek[3] = 1;
+                            nev[3] = 1;
+                        }
+                        x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
+                        x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
+                        x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
+                        ggml_set_param(ctx0, x[0]);
+                        ggml_set_param(ctx0, x[1]);
+                        ggml_set_param(ctx0, x[2]);
 
-                    struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
 
-                    check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+                        check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+                    }
                 }
             }
         }
@@ -1511,6 +1583,7 @@ int main(int argc, const char ** argv) {
         // flash_attn f16, not yet fully implemented
         if(0)
         {
+            srand(seed);
             const int nargs = 3;
 
             int64_t ne2[4];

Beberapa file tidak ditampilkan karena terlalu banyak file yang berubah dalam diff ini