| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271 |
- #include "ggml.h"
- #include "ggml-alloc.h"
- #include "common.h"
- #include "llama.h"
- #include <unordered_map>
- #include <vector>
- #include <cassert>
- #include <climits>
- #include <cstring>
- #include <cstdarg>
- #include <ctime>
- #include <random>
- #include <stdexcept>
- #include <algorithm>
- #include <string>
- #if defined(_MSC_VER)
- #pragma warning(disable: 4244 4267) // possible loss of data
- #endif
// Sampler state for normally distributed values; frand_normal() clamps every
// draw into [min, max].
struct random_normal_distribution {
    std::mt19937                    gen;  // pseudo-random engine
    std::normal_distribution<float> rd;   // mean / standard deviation of the draws
    float                           min;  // lower clamp bound
    float                           max;  // upper clamp bound
};
// Sampler state for uniformly distributed real values.
struct random_uniform_distribution {
    std::mt19937                          gen;  // pseudo-random engine
    std::uniform_real_distribution<float> rd;   // sampling range
};
- void init_random_normal_distribution(struct random_normal_distribution * rnd, int seed, float mean, float std, float min, float max) {
- rnd->gen = std::mt19937(seed);
- rnd->rd = std::normal_distribution<float>{mean, std};
- rnd->min = min;
- rnd->max = max;
- }
- void init_random_uniform_distribution(struct random_uniform_distribution * rnd, int seed, float min, float max) {
- rnd->gen = std::mt19937(seed);
- rnd->rd = std::uniform_real_distribution<float>{min, max};
- }
// Limit `v` to the closed range [min, max]; `min` is checked first.
int clamp(const int v, const int min, const int max) {
    if (v < min) {
        return min;
    }
    if (v > max) {
        return max;
    }
    return v;
}
// Limit `v` to the closed range [min, max]; `min` is checked first.
// A NaN `v` fails both comparisons and is returned unchanged.
float fclamp(const float v, const float min, const float max) {
    if (v < min) {
        return min;
    }
    if (v > max) {
        return max;
    }
    return v;
}
// Return rand() scaled into [0, 1] (inclusive at both ends); uses the global rand() state.
float frand() {
    const float numerator   = static_cast<float>(rand());
    const float denominator = static_cast<float>(RAND_MAX);
    return numerator / denominator;
}
- float frand_normal(struct random_normal_distribution * rnd) {
- return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
- }
- float frand_uniform(struct random_uniform_distribution * rnd) {
- return rnd->rd(rnd->gen);
- }
- struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
- float scale = 1.0f; // xavier
- switch (tensor->n_dims) {
- case 1:
- scale /= sqrtf(tensor->ne[0]);
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
- *dst = scale * frand_normal(rnd);
- }
- break;
- case 2:
- scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
- *dst = scale * frand_normal(rnd);
- }
- }
- break;
- case 3:
- scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
- for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
- *dst = scale * frand_normal(rnd);
- }
- }
- }
- break;
- case 4:
- scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
- for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
- for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
- *dst = scale * frand_normal(rnd);
- }
- }
- }
- }
- break;
- default:
- assert(false);
- };
- return tensor;
- }
- struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
- switch (tensor->n_dims) {
- case 1:
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
- *dst = frand_uniform(rnd);
- }
- break;
- case 2:
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
- *dst = frand_uniform(rnd);
- }
- }
- break;
- case 3:
- for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
- *dst = frand_uniform(rnd);
- }
- }
- }
- break;
- case 4:
- for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
- for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
- *dst = frand_uniform(rnd);
- }
- }
- }
- }
- break;
- default:
- assert(false);
- };
- return tensor;
- }
// Model hyper-parameters with their default values.
struct my_llama_hparams {
    // vocabulary / context
    uint32_t n_vocab = 32000;
    uint32_t n_ctx   = 512;

    // transformer dimensions
    uint32_t n_embd  = 4096;
    uint32_t n_head  = 32;
    uint32_t n_layer = 32;
    uint32_t n_rot   = 64;
    uint32_t n_ff    = 11008;

    // RMS-norm epsilon (llama-style normalisation)
    float f_norm_rms_eps = 1e-5;

    // rotary embedding frequency base and scale
    float rope_freq_base  = 10000.0f;
    float rope_freq_scale = 1.0f;
};
// Weights of one transformer block (the tensors are created in init_model).
struct my_llama_layer {
    struct ggml_tensor * attention_norm;  // pre-attention norm weight

    // attention projections
    struct ggml_tensor * wq;
    struct ggml_tensor * wk;
    struct ggml_tensor * wv;
    struct ggml_tensor * wo;

    struct ggml_tensor * ffn_norm;        // pre-feed-forward norm weight

    // feed-forward weights
    struct ggml_tensor * w1;
    struct ggml_tensor * w2;
    struct ggml_tensor * w3;
};
- struct my_llama_model {
- struct ggml_context * ctx = NULL;
- my_llama_hparams hparams;
- struct ggml_tensor * tok_embeddings;
- struct ggml_tensor * norm;
- struct ggml_tensor * output;
- std::vector<my_llama_layer> layers;
- uint32_t train_its = 0;
- uint32_t train_samples = 0;
- uint32_t train_tokens = 0;
- };
// gguf constants
//
// Metadata keys (LLM_KV_*) and tensor names (LLM_TENSOR_*) for optimizer and
// training state. Their users are not visible in this chunk; presumably they
// are used when saving/loading training checkpoints — confirm against callers.
// The string values are file-format identifiers: do not edit them.

// optimizer: common keys
const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam";     // presumably a value stored under LLM_KV_OPTIMIZER_TYPE
const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";   // presumably a value stored under LLM_KV_OPTIMIZER_TYPE
const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version";
const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count";
const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count";
const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count";
const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized";
// optimizer: Adam-specific keys
const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss";
const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss";
const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count";
// optimizer: L-BFGS-specific keys
const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss";
const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step";
const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j";
const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k";
const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end";
const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";
// optimizer state tensors: Adam
const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments";
const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments";
const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values";
// optimizer state tensors: L-BFGS
const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
// training progress keys
const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
// gguf constants (sync with gguf.py)
//
// Values containing "%s" are printf patterns (presumably filled with the
// architecture name — confirm at the call site). Tensor names containing "%d"
// are filled with the block index; init_model formats them with snprintf and
// then appends ".weight".

// general / hyper-parameter keys
const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";
const char * LLM_KV_CONTEXT_LENGTH = "%s.context_length";
const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length";
const char * LLM_KV_BLOCK_COUNT = "%s.block_count";
const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length";
const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count";
const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count";
const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp
const char * LLM_KV_ROPE_SCALE_LINEAR = "%s.rope.scale_linear";
// tokenizer keys
const char * LLM_KV_TOKENIZER_MODEL = "tokenizer.ggml.model";
const char * LLM_KV_TOKENIZER_LIST = "tokenizer.ggml.tokens";
const char * LLM_KV_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type";
const char * LLM_KV_TOKENIZER_SCORES = "tokenizer.ggml.scores";
const char * LLM_KV_TOKENIZER_MERGES = "tokenizer.ggml.merges";
const char * LLM_KV_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id";
const char * LLM_KV_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id";
const char * LLM_KV_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id";
// NOTE: "seperator" is spelled this way in the key string itself; presumably it
// must match gguf.py byte-for-byte, so do not "fix" the spelling here alone.
const char * LLM_KV_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id";
const char * LLM_KV_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id";
// tensor names: global weights
const char * LLM_TENSOR_TOKEN_EMBD = "token_embd";
const char * LLM_TENSOR_OUTPUT_NORM = "output_norm";
const char * LLM_TENSOR_OUTPUT = "output";
// tensor names: per-block weights (%d = block index)
const char * LLM_TENSOR_ATTN_NORM = "blk.%d.attn_norm";
const char * LLM_TENSOR_ATTN_Q = "blk.%d.attn_q";
const char * LLM_TENSOR_ATTN_K = "blk.%d.attn_k";
const char * LLM_TENSOR_ATTN_V = "blk.%d.attn_v";
const char * LLM_TENSOR_ATTN_OUT = "blk.%d.attn_output";
const char * LLM_TENSOR_FFN_NORM = "blk.%d.ffn_norm";
const char * LLM_TENSOR_FFN_GATE = "blk.%d.ffn_gate";
const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down";
const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";
- void print_params(struct my_llama_hparams * params) {
- printf("%s: n_vocab: %d\n", __func__, params->n_vocab);
- printf("%s: n_ctx: %d\n", __func__, params->n_ctx);
- printf("%s: n_embd: %d\n", __func__, params->n_embd);
- printf("%s: n_head: %d\n", __func__, params->n_head);
- printf("%s: n_ff: %d\n", __func__, params->n_ff);
- printf("%s: n_layer: %d\n", __func__, params->n_layer);
- printf("%s: n_rot: %d\n", __func__, params->n_rot);
- }
- void init_model(struct my_llama_model * model) {
- const auto & hparams = model->hparams;
- const uint32_t n_embd = hparams.n_embd;
- const uint32_t n_layer = hparams.n_layer;
- const uint32_t n_vocab = hparams.n_vocab;
- const uint32_t n_ff = hparams.n_ff;
- struct ggml_context * ctx = model->ctx;
- model->train_its = 0;
- model->train_samples = 0;
- model->train_tokens = 0;
- std::vector<char> tn_buf;
- tn_buf.resize(GGML_MAX_NAME);
- auto tn = [&tn_buf](const char * key) -> const char * {
- snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
- return tn_buf.data();
- };
- auto tni = [&tn_buf](const char * key, int bid) -> const char * {
- snprintf(tn_buf.data(), tn_buf.size(), key, bid);
- std::string s = tn_buf.data();
- snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
- return tn_buf.data();
- };
- model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
- model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
- model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
- ggml_set_name(model->tok_embeddings, tn(LLM_TENSOR_TOKEN_EMBD));
- ggml_set_name(model->norm, tn(LLM_TENSOR_OUTPUT_NORM));
- ggml_set_name(model->output, tn(LLM_TENSOR_OUTPUT));
- model->layers.resize(n_layer);
- for (uint32_t i = 0; i < n_layer; ++i) {
- auto & layer = model->layers[i];
- layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
- layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
- layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
- layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
- layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
- layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
- layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
- layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
- layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
- ggml_set_name(layer.attention_norm, tni(LLM_TENSOR_ATTN_NORM, i));
- ggml_set_name(layer.wq, tni(LLM_TENSOR_ATTN_Q, i));
- ggml_set_name(layer.wk, tni(LLM_TENSOR_ATTN_K, i));
- ggml_set_name(layer.wv, tni(LLM_TENSOR_ATTN_V, i));
- ggml_set_name(layer.wo, tni(LLM_TENSOR_ATTN_OUT, i));
- ggml_set_name(layer.ffn_norm, tni(LLM_TENSOR_FFN_NORM, i));
- ggml_set_name(layer.w1, tni(LLM_TENSOR_FFN_GATE, i));
- ggml_set_name(layer.w2, tni(LLM_TENSOR_FFN_DOWN, i));
- ggml_set_name(layer.w3, tni(LLM_TENSOR_FFN_UP, i));
- }
- }
- void set_param_model(struct my_llama_model * model) {
- const auto& hparams = model->hparams;
- const uint32_t n_layer = hparams.n_layer;
- struct ggml_context* ctx = model->ctx;
- ggml_set_param(ctx, model->tok_embeddings);
- ggml_set_param(ctx, model->norm);
- ggml_set_param(ctx, model->output);
- for (uint32_t i = 0; i < n_layer; ++i) {
- auto & layer = model->layers[i];
- ggml_set_param(ctx, layer.attention_norm);
- ggml_set_param(ctx, layer.wq);
- ggml_set_param(ctx, layer.wk);
- ggml_set_param(ctx, layer.wv);
- ggml_set_param(ctx, layer.wo);
- ggml_set_param(ctx, layer.ffn_norm);
- ggml_set_param(ctx, layer.w1);
- ggml_set_param(ctx, layer.w2);
- ggml_set_param(ctx, layer.w3);
- }
- }
- void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
- const auto & hparams = model->hparams;
- const uint32_t n_layer = hparams.n_layer;
- struct random_normal_distribution rnd;
- init_random_normal_distribution(&rnd, seed, mean, std, min, max);
- randomize_tensor_normal(model->tok_embeddings, &rnd);
- randomize_tensor_normal(model->norm, &rnd);
- randomize_tensor_normal(model->output, &rnd);
- for (uint32_t i = 0; i < n_layer; ++i) {
- auto & layer = model->layers[i];
- randomize_tensor_normal(layer.attention_norm, &rnd);
- randomize_tensor_normal(layer.wq, &rnd);
- randomize_tensor_normal(layer.wk, &rnd);
- randomize_tensor_normal(layer.wv, &rnd);
- randomize_tensor_normal(layer.wo, &rnd);
- randomize_tensor_normal(layer.ffn_norm, &rnd);
- randomize_tensor_normal(layer.w1, &rnd);
- randomize_tensor_normal(layer.w2, &rnd);
- randomize_tensor_normal(layer.w3, &rnd);
- }
- }
// Shape checks: abort (GGML_ASSERT) unless `tensor` has exactly the given rank
// and extents. Each check is a separate GGML_ASSERT so the failure message
// names the offending dimension.

// Require rank 1 with ne = [ne0].
void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
    GGML_ASSERT(tensor->n_dims == 1);
    GGML_ASSERT(tensor->ne[0] == ne0);
}
// Require rank 2 with ne = [ne0, ne1].
void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
    GGML_ASSERT(tensor->n_dims == 2);
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
}
// Require rank 3 with ne = [ne0, ne1, ne2].
void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
    GGML_ASSERT(tensor->n_dims == 3);
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
    GGML_ASSERT(tensor->ne[2] == ne2);
}
// Require rank 4 with ne = [ne0, ne1, ne2, ne3].
void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
    GGML_ASSERT(tensor->n_dims == 4);
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
    GGML_ASSERT(tensor->ne[2] == ne2);
    GGML_ASSERT(tensor->ne[3] == ne3);
}
- static size_t hash(void * p) {
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
- }
- static size_t hash_find(void * hash_table[], void * p) {
- size_t h = hash(p);
- // linear probing
- size_t i = h;
- while (hash_table[i] != NULL && hash_table[i] != p) {
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
- if (i == h) {
- // visited all hash table entries -> not found
- return GGML_GRAPH_HASHTABLE_SIZE;
- }
- }
- return i;
- }
- static bool hash_insert(void * hash_table[], void * p) {
- //size_t h = hash(p);
- size_t i = hash_find(hash_table, p);
- GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
- if (hash_table[i] == p) {
- return true;
- }
- // insert
- GGML_ASSERT(hash_table[i] == NULL);
- hash_table[i] = p;
- return false;
- }
- static bool hash_contains(void * hash_table[], void * p) {
- size_t i = hash_find(hash_table, p);
- return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
- }
- struct hash_map {
- void * keys[GGML_GRAPH_HASHTABLE_SIZE];
- void * vals[GGML_GRAPH_HASHTABLE_SIZE];
- };
- //static const size_t HASH_MAP_SIZE = sizeof(struct hash_map);
- struct hash_map * new_hash_map() {
- struct hash_map * result = new struct hash_map;
- for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
- result->keys[i] = NULL;
- result->vals[i] = NULL;
- }
- return result;
- };
- void free_hash_map(struct hash_map * map) {
- delete map;
- }
- static bool ggml_is_view(struct ggml_tensor * t) {
- return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
- t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
- }
- static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
- switch (t->op) {
- case GGML_OP_PERMUTE:
- case GGML_OP_RESHAPE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_VIEW:
- return t->src[0];
- case GGML_OP_CPY:
- return t->src[1];
- default:
- return NULL;
- }
- }
- static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
- struct ggml_tensor * parent = t;
- do {
- parent = get_view_parent(parent);
- } while (ggml_is_view(parent));
- return parent;
- }
// Return a tensor that recomputes `node` for the gradient-checkpointed backward
// graph. The part of the forward graph `graph` feeding `node` is cloned
// recursively, stopping at tensors already present in `replacements`.
//
// Returned unchanged (no clone made):
//   - NULL
//   - parameter tensors (is_param)
//   - tensors not in graph->visited_hash_table (not part of `graph`)
//   - tensors with no src inputs (leaves)
//
// `replacements` memoises node -> tensor to use. In this file the caller
// (ggml_build_backward_gradient_checkpointing) pre-seeds it with checkpoints
// mapped to themselves, so checkpoints are reused rather than recomputed.
// New clones are allocated in `ctx` and named "<name> (clone)".
struct ggml_tensor * ggml_recompute_graph_node(
        struct ggml_context * ctx,
        struct ggml_cgraph  * graph,
        struct hash_map     * replacements,
        struct ggml_tensor  * node) {

    if (node == NULL) {
        return NULL;
    }

    if (node->is_param) {
        return node;
    }

    if (!hash_contains(graph->visited_hash_table, node)) {
        return node;
    }

    int count_children = 0;
    for (int k = 0; k < GGML_MAX_SRC; ++k) {
        if (node->src[k]) {
            ++count_children;
        }
    }

    if (count_children == 0) {
        return node;
    }

    // Already replaced (a checkpoint or a previously made clone)? Reuse it.
    size_t i = hash_find(replacements->keys, node);
    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
    if (replacements->keys[i] == node) {
        return (struct ggml_tensor *) replacements->vals[i];
    }

    struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);

    // insert clone into replacements
    // Done before recursing: slot i was found before the recursive calls below
    // make their own insertions, so it is claimed now while still known to be free.
    GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
    replacements->keys[i] = node;
    replacements->vals[i] = clone;

    // Copy the op and metadata; strides nb[] are copied verbatim from node.
    clone->op       = node->op;
    clone->grad     = node->grad;
    clone->is_param = node->is_param;
    clone->extra    = node->extra;
    for (int k = 0; k < GGML_MAX_DIMS; ++k) {
        clone->nb[k] = node->nb[k];
    }
    // Inputs are themselves recomputed (or replaced by checkpoints).
    for (int k = 0; k < GGML_MAX_SRC; ++k) {
        clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
    }
    // View-like clones share storage with the non-view tensor at the end of their view chain.
    // NOTE(review): this uses the source's base pointer without any view offset;
    // presumably data is (re)assigned later (e.g. by the allocator) — confirm.
    if (ggml_is_view(clone)) {
        struct ggml_tensor * source = get_view_source(clone);
        GGML_ASSERT(source != NULL);
        clone->data = source->data;
    }

    // op_params is copied as raw bytes; these asserts pin the sizes memcpy relies on.
    GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
    GGML_ASSERT(sizeof(node->name)      == GGML_MAX_NAME);
    memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
    ggml_format_name(clone, "%s (clone)", ggml_get_name(node));

    return clone;
};
// Build the backward graph `gb` for the forward graph `gf` using gradient
// checkpointing: backward nodes do not read forward intermediates directly.
// Instead, those intermediates are recomputed (ggml_recompute_graph_node)
// from the given checkpoint tensors.
//
//   ctx           context in which recomputation clones are created
//   gf            forward graph
//   gb            output: starts as a copy of gf; the rewritten backward nodes
//                 (and any clones they depend on) are then expanded into it
//   gb_tmp        scratch: receives gf plus an ordinary backward expansion.
//                 Its backward nodes' src[] pointers are rewritten in place.
//   checkpoints   tensors reused as-is during recomputation
//   n_checkpoints number of checkpoints; if <= 0, gb is simply the ordinary
//                 (non-checkpointed) backward graph held in gb_tmp
void ggml_build_backward_gradient_checkpointing(
        struct ggml_context   * ctx,
        struct ggml_cgraph    * gf,
        struct ggml_cgraph    * gb,
        struct ggml_cgraph    * gb_tmp,
        struct ggml_tensor  * * checkpoints,
        int                     n_checkpoints) {
    *gb_tmp = *gf;
    ggml_build_backward_expand(ctx, gf, gb_tmp, true);

    if (n_checkpoints <= 0) {
        *gb = *gb_tmp;
        return;
    }

    struct hash_map * replacements = new_hash_map();

    // insert checkpoints in replacements
    // Each checkpoint maps to itself, so recomputation stops at it.
    for (int i = 0; i < n_checkpoints; ++i) {
        size_t k = hash_find(replacements->keys, checkpoints[i]);
        GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
        GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
        replacements->keys[k] = checkpoints[i];
        replacements->vals[k] = checkpoints[i];
    }

    *gb = *gf;
    // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
    // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
    // by recomputing them from checkpoints
    for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
        struct ggml_tensor * node = gb_tmp->nodes[i];
        for (int k = 0; k < GGML_MAX_SRC; ++k) {
            // insert new tensors recomputing src, reusing already made replacements,
            // remember replacements: remember new tensors with mapping from corresponding gf nodes
            // recurse for input tensors,
            // unless (i.e. terminating when) input tensors are checkpoints
            node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
        }
        // insert rewritten backward node with replacements made into resulting backward graph gb
        ggml_build_forward_expand(gb, node);
    }

    free_hash_map(replacements);
}
- struct ggml_tensor * llama_build_train_graphs(
- struct my_llama_model * model,
- struct ggml_allocr * alloc,
- struct ggml_context * ctx,
- struct ggml_cgraph * gf,
- struct ggml_cgraph * gb,
- struct ggml_cgraph * gb_tmp,
- struct ggml_tensor * * logits,
- struct ggml_tensor * tokens_input,
- struct ggml_tensor * targets,
- const int n_tokens,
- const int n_batch,
- const bool enable_flash_attn,
- const bool enable_checkpointing) {
- ggml_set_scratch(ctx, { 0, 0, nullptr, });
- const int n_past = 0;
- const int N = n_tokens;
- const auto & hparams = model->hparams;
- const int n_ctx = hparams.n_ctx;
- const int n_vocab = hparams.n_vocab;
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_head = hparams.n_head;
- const int n_rot = hparams.n_rot;
- const int n_ff = hparams.n_ff;
- const float f_norm_rms_eps = hparams.f_norm_rms_eps;
- const float rope_freq_base = hparams.rope_freq_base;
- const float rope_freq_scale = hparams.rope_freq_scale;
- auto set_name = [](struct ggml_tensor * t, const char * n) {
- ggml_set_name(t, n);
- if (t->grad) {
- ggml_format_name(t->grad, "%s->grad", n);
- }
- };
- // rope has so much parameters that we make a custom function for it
- auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
- (struct ggml_tensor * t) -> struct ggml_tensor * {
- // not capturing these, to silcence warnings
- const int n_past = 0;
- const int rope_mode = 0;
- return ggml_rope_custom(ctx,
- t, n_past, n_rot, rope_mode, n_ctx,
- rope_freq_base, rope_freq_scale);
- };
- set_name(tokens_input, "tokens_input");
- set_name(targets, "targets");
- GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
- struct ggml_tensor * t00 = ggml_reshape_1d(ctx, tokens_input, N*n_batch); set_name(t00, "t00"); assert_shape_1d(t00, N*n_batch);
- struct ggml_tensor * t01 = ggml_get_rows(ctx, model->tok_embeddings, t00); set_name(t01, "t01"); assert_shape_2d(t01, n_embd, N*n_batch);
- struct ggml_tensor * cur = t01;
- std::vector<struct ggml_tensor *> checkpoints;
- checkpoints.push_back(tokens_input);
- checkpoints.push_back(targets);
- checkpoints.push_back(t00);
- checkpoints.push_back(t01);
- struct ggml_tensor * kv_scale;
- if (!enable_flash_attn) {
- kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
- }
- for (int il = 0; il < n_layer; ++il) {
- struct my_llama_layer & layer = model->layers[il];
- struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t02, "t02"); assert_shape_2d(t02, n_embd, N*n_batch);
- struct ggml_tensor * t03 = ggml_repeat (ctx, layer.attention_norm, t02); set_name(t03, "t03"); assert_shape_2d(t03, n_embd, N*n_batch);
- struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); set_name(t04, "t04"); assert_shape_2d(t04, n_embd, N*n_batch);
- struct ggml_tensor * t05 = ggml_mul_mat (ctx, layer.wq, t04); set_name(t05, "t05"); assert_shape_2d(t05, n_embd, N*n_batch);
- struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); set_name(t06, "t06"); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
- struct ggml_tensor * t07 = rope (t06); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
- struct ggml_tensor * t08 = ggml_mul_mat (ctx, layer.wk, t04); set_name(t08, "t08"); assert_shape_2d(t08, n_embd, N*n_batch);
- struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd/n_head, n_head, N, n_batch); set_name(t09, "t09"); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
- struct ggml_tensor * t10 = rope (t09); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
- struct ggml_tensor * t11 = ggml_mul_mat (ctx, t04, layer.wv); set_name(t11, "t11"); assert_shape_2d(t11, N*n_batch, n_embd);
- struct ggml_tensor * t12 = ggml_reshape_4d (ctx, t11, N, n_batch, n_embd/n_head, n_head); set_name(t12, "t12"); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
- struct ggml_tensor * t13 = ggml_permute (ctx, t07, 0, 2, 1, 3); set_name(t13, "t13"); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
- struct ggml_tensor * t14 = ggml_permute (ctx, t10, 0, 2, 1, 3); set_name(t14, "t14"); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
- struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
- struct ggml_tensor * t16;
- if (enable_flash_attn) {
- t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
- } else {
- struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
- struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
- struct ggml_tensor * t16_2 = ggml_diag_mask_inf_inplace(ctx, t16_1, n_past); set_name(t16_2, "t16_2"); assert_shape_4d(t16_2, N, N, n_head, n_batch);
- struct ggml_tensor * t16_3 = ggml_soft_max_inplace (ctx, t16_2); set_name(t16_3, "t16_3"); assert_shape_4d(t16_3, N, N, n_head, n_batch);
- t16 = ggml_mul_mat(ctx, t15, t16_3); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
- }
- struct ggml_tensor * t17 = ggml_permute (ctx, t16, 0, 2, 1, 3); set_name(t17, "t17"); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
- struct ggml_tensor * t18 = ggml_cont (ctx, t17); set_name(t18, "t18"); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
- struct ggml_tensor * t19 = ggml_reshape_2d (ctx, t18, n_embd, N*n_batch); set_name(t19, "t19"); assert_shape_2d(t19, n_embd, N*n_batch);
- struct ggml_tensor * t20 = ggml_mul_mat (ctx, layer.wo, t19); set_name(t20, "t20"); assert_shape_2d(t20, n_embd, N*n_batch);
- struct ggml_tensor * t21 = ggml_add (ctx, t20, cur); set_name(t21, "t21"); assert_shape_2d(t21, n_embd, N*n_batch);
- struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, f_norm_rms_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch);
- struct ggml_tensor * t23 = ggml_repeat (ctx, layer.ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch);
- struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch);
- struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.w3, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
- struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.w1, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
- struct ggml_tensor * t27 = ggml_silu (ctx, t26); set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch);
- struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch);
- struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
- struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
- cur = t30;
- checkpoints.push_back(cur);
- }
- struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch);
- struct ggml_tensor * t32 = ggml_repeat (ctx, model->norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch);
- struct ggml_tensor * t33 = ggml_mul (ctx, t32, t31); set_name(t33, "t33"); assert_shape_2d(t33, n_embd, N*n_batch);
- struct ggml_tensor * t34 = ggml_mul_mat (ctx, model->output, t33); set_name(t34, "t34"); assert_shape_2d(t34, n_vocab, N*n_batch);
- struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); set_name(t35, "t35"); assert_shape_3d(t35, n_vocab, N, n_batch);
- struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); set_name(t36, "t36"); assert_shape_1d(t36, 1);
- checkpoints.push_back(t31);
- checkpoints.push_back(t32);
- checkpoints.push_back(t33);
- checkpoints.push_back(t34);
- checkpoints.push_back(t35);
- checkpoints.push_back(t36);
- ggml_build_forward_expand(gf, t36);
- if (enable_checkpointing) {
- ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
- } else {
- *gb = *gf;
- ggml_build_backward_expand(ctx, gf, gb, true);
- }
- if (alloc) {
- // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
- int n_leafs_before = gb->n_leafs;
- int n_nodes_before = gb->n_nodes;
- struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
- // output tensors
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
- // input gradient
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
- GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad));
- ggml_allocr_alloc(alloc, t36->grad);
- // gradient tensors (will be set to zero by ggml_graph_reset)
- // pinning these produces large unnecessary memory overhead, which will be resolved by PR 2632
- for (int i = 0; i < gf->n_nodes; ++i) {
- if (!gf->grads[i]) continue;
- if (gf->grads[i]->data == NULL && !ggml_is_view(gf->grads[i])) {
- ggml_allocr_alloc(alloc, gf->grads[i]);
- }
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, gf->grads[i], one));
- }
- // allocating checkpoints in one block to reduce memory fragmentation
- // note: they will be freed in reverse order
- for (int i = 0; i < (int) checkpoints.size(); ++i) {
- if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) {
- ggml_allocr_alloc(alloc, checkpoints[i]);
- }
- }
- //int n_leafs_after = gb->n_leafs;
- //int n_nodes_after = gb->n_nodes;
- ggml_allocr_alloc_graph(alloc, gb);
- // remove the additional nodes and leafs
- for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
- gb->leafs[i] = NULL;
- }
- for (int i = n_nodes_before; i < gb->n_nodes; ++i) {
- gb->nodes[i] = NULL;
- }
- gb->n_leafs = n_leafs_before;
- gb->n_nodes = n_nodes_before;
- }
- *logits = t35;
- return t36;
- }
- // Store `value` into the f32 element of `tensor` at index (i0, i1, i2); the address is computed
- // from the byte strides nb[0], nb[1] and nb[2], so non-contiguous views are handled too.
- void set_f32_3d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int64_t i2, float value) {
- char * base = (char *) tensor->data;
- const size_t offset = i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2];
- *(float *) (base + offset) = value;
- }
- // Store `value` into the f32 element of `tensor` at index (i0, i1), addressed via byte strides nb[0], nb[1].
- void set_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, float value) {
- const size_t offset = i0*tensor->nb[0] + i1*tensor->nb[1];
- float * dst = (float *) ((char *) tensor->data + offset);
- *dst = value;
- }
- // Store `value` into the i32 element of `tensor` at index (i0, i1), addressed via byte strides nb[0], nb[1].
- void set_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int32_t value) {
- const size_t offset = i0*tensor->nb[0] + i1*tensor->nb[1];
- int32_t * dst = (int32_t *) ((char *) tensor->data + offset);
- *dst = value;
- }
- // Read the f32 element of `tensor` at index (i0, i1), addressed via byte strides nb[0], nb[1].
- float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
- const size_t offset = i0*tensor->nb[0] + i1*tensor->nb[1];
- const float * src = (const float *) ((const char *) tensor->data + offset);
- return *src;
- }
- // Read the i32 element of `tensor` at index (i0, i1), addressed via byte strides nb[0], nb[1].
- int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
- const size_t offset = i0*tensor->nb[0] + i1*tensor->nb[1];
- const int32_t * src = (const int32_t *) ((const char *) tensor->data + offset);
- return *src;
- }
- // Print row `i` of the 2-D f32 tensor `probs` on a single line, each value as " %.2f".
- void print_row(struct ggml_tensor * probs, int i) {
- const int64_t n_cols = probs->ne[0];
- for (int64_t col = 0; col < n_cols; ++col) {
- printf(" %.2f", get_f32_2d(probs, col, i));
- }
- printf("\n");
- }
- // Print every row of the 2-D f32 tensor `probs`, one line per row, via print_row.
- void print_matrix(struct ggml_tensor * probs) {
- assert(probs->n_dims == 2);
- for (int row = 0; row < probs->ne[1]; ++row) {
- print_row(probs, row);
- }
- }
- // Fill a single training example taken at offset train_samples[example_id % n_train_samples]:
- // tokens_input receives BOS followed by the first n_tokens-1 sample tokens, and for every position
- // the following sample token is marked with +1 in target_logits (all other entries -1/n_vocab)
- // and in target_probs (all other entries 0). Token ids are clamped to [0, n_vocab-1].
- void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
- const int n_tokens = tokens_input->ne[0];
- const int n_vocab = target_logits->ne[0];
- const size_t sample = train_samples[example_id % n_train_samples];
- GGML_ASSERT(sample+n_tokens-1 < n_train_data);
- ggml_set_f32(target_logits, -1.0f/n_vocab);
- ggml_set_f32(target_probs, 0.0f);
- ggml_set_i32_1d(tokens_input, 0, llama_token_bos(lctx));
- for (int pos = 0; pos < n_tokens; ++pos) {
- const int token = clamp(train_data[sample+pos], 0, n_vocab-1);
- set_f32_2d(target_logits, token, pos, +1.0f);
- set_f32_2d(target_probs, token, pos, +1.0f);
- // the target at `pos` becomes the input at `pos+1` (shifted by the leading BOS)
- if (pos+1 < n_tokens) {
- ggml_set_i32_1d(tokens_input, pos+1, token);
- }
- }
- }
- // Fill a batch of n_batch training examples. Batch row b uses the sample offset
- // train_samples[(example_id*n_batch + b) % n_train_samples] and is laid out like get_example_targets:
- // BOS followed by n_tokens-1 input tokens, with the following token marked +1 in target_logits
- // (default -1/n_vocab) and target_probs (default 0). tokens_input must be [n_tokens, n_batch] and
- // both target tensors [n_vocab, n_tokens, n_batch] (asserted).
- void get_example_targets_batch(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
- GGML_ASSERT(tokens_input->n_dims == 2);
- GGML_ASSERT(target_logits->n_dims == 3);
- GGML_ASSERT(target_probs->n_dims == 3);
- const int n_vocab = target_logits->ne[0];
- const int n_tokens = tokens_input->ne[0];
- const int n_batch = tokens_input->ne[1];
- GGML_ASSERT(n_tokens == target_logits->ne[1]);
- GGML_ASSERT(n_batch == target_logits->ne[2]);
- GGML_ASSERT(n_vocab == target_probs->ne[0]);
- GGML_ASSERT(n_tokens == target_probs->ne[1]);
- GGML_ASSERT(n_batch == target_probs->ne[2]);
- ggml_set_f32(target_logits, -1.0f/n_vocab);
- ggml_set_f32(target_probs, 0.0f);
- for (int b = 0; b < n_batch; ++b) {
- const size_t sample_idx = (example_id*n_batch + b) % n_train_samples;
- const size_t sample = train_samples[sample_idx];
- GGML_ASSERT(sample+n_tokens-1 < n_train_data);
- set_i32_2d(tokens_input, 0, b, llama_token_bos(lctx));
- for (int pos = 0; pos < n_tokens; ++pos) {
- const int token = clamp(train_data[sample+pos], 0, n_vocab-1);
- set_f32_3d(target_logits, token, pos, b, +1.0f);
- set_f32_3d(target_probs, token, pos, b, +1.0f);
- // the target at `pos` becomes the input at `pos+1` (shifted by the leading BOS)
- if (pos+1 < n_tokens) {
- set_i32_2d(tokens_input, pos+1, b, token);
- }
- }
- }
- }
- // Read the whole file `filename` and tokenize its text with the vocabulary of `lctx` (no BOS added).
- // On success `out` holds exactly the produced tokens and the token count is returned; an empty file
- // yields 0 tokens. Returns 0 and leaves `out` untouched if the file cannot be opened. Seek failures
- // abort via GGML_ASSERT; tell/read failures terminate via die_fmt/die. The file is always closed
- // before tokenizing.
- int tokenize_file(struct llama_context * lctx, const char * filename, std::vector<llama_token>& out) {
- FILE * fp = std::fopen(filename, "rb");
- if (fp == NULL) {
- return 0;
- }
- // determine the file size (64-bit offsets on Windows, where long is 32-bit)
- #ifdef _WIN32
- GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_END) == 0);
- __int64 ret = _ftelli64(fp);
- #else
- GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_END) == 0);
- long ret = std::ftell(fp);
- #endif
- if (ret < 0) {
- const int tell_errno = errno;
- std::fclose(fp);
- die_fmt("ftell failed: %s", strerror(tell_errno));
- }
- const size_t size = (size_t) ret;
- #ifdef _WIN32
- GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_SET) == 0);
- #else
- GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_SET) == 0);
- #endif
- // one extra byte so the buffer is also a NUL-terminated C string
- std::vector<char> buf(size+1);
- // fread of zero bytes reports 0 items read, so only read non-empty files
- const bool read_ok = size == 0 || std::fread(buf.data(), size, 1, fp) == 1;
- const int read_errno = errno;
- const bool read_error = ferror(fp) != 0;
- std::fclose(fp);
- // check ferror first so a real I/O error is not misreported as a short file
- if (read_error) {
- die_fmt("fread failed: %s", strerror(read_errno));
- }
- if (!read_ok) {
- die("unexpectedly reached end of file");
- }
- buf[size] = '\0';
- // pass only the `size` bytes of file text: buf.size() would also hand the NUL terminator to the
- // tokenizer as part of the text
- out.resize(size+1);
- int n_tokens = llama_tokenize(lctx, buf.data(), (int) size, out.data(), (int) out.size(), false);
- if (n_tokens < 0) {
- // out was too small; -n_tokens is used as the required capacity
- out.resize(-n_tokens);
- n_tokens = llama_tokenize(lctx, buf.data(), (int) size, out.data(), (int) out.size(), false);
- }
- GGML_ASSERT(n_tokens >= 0);
- out.resize(n_tokens);
- // debug aid: check that concatenating the token pieces reproduces the original text
- bool verify = false;
- if (verify) {
- const char * in = buf.data();
- const char * end = buf.data() + size;
- for (int i = 0; i < (int) out.size(); ++i) {
- std::string s = llama_token_to_piece(lctx, out[i]);
- int len = s.length();
- if (in >= end) {
- printf("%s: unexpected end of original text.\n", __func__);
- break;
- }
- const bool matches = (strncmp(in, s.c_str(), len) == 0);
- if (matches) {
- in += len;
- } else {
- printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s.c_str());
- }
- }
- }
- return n_tokens;
- }
- // Shuffle the integers in [begin, end) in place: slot v of a key table covering 0..max receives
- // frand() (drawn in order v = 0..max), and the range is sorted by the key of each value.
- // NOTE(review): values must be non-negative, and the table holds max+1 floats.
- void shuffle_ints(int * begin, int * end) {
- if (end <= begin) return;
- const int max_value = *std::max_element(begin, end);
- std::vector<float> keys(max_value + 1);
- for (size_t v = 0; v < keys.size(); ++v) {
- keys[v] = frand();
- }
- std::sort(begin, end, [&keys](int lhs, int rhs) {
- return keys.at(lhs) < keys.at(rhs);
- });
- }
- #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
- { \
- const std::string skey(key); \
- const int kid = gguf_find_key(ctx, skey.c_str()); \
- if (kid >= 0) { \
- enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
- if (ktype != (type)) { \
- die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
- } \
- (dst) = func(ctx, kid); \
- } else if (req) { \
- die_fmt("key not found in model: %s", skey.c_str()); \
- } \
- }
- // Check that tensor data can be memcpy'd between a and b: both non-NULL, same type, same shape,
- // and both contiguous. Every check is a GGML_ASSERT (aborts on failure), so this only ever returns
- // true; callers wrap it in GGML_ASSERT, as read_tensor_by_name does.
- bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
- GGML_ASSERT(a != NULL);
- GGML_ASSERT(b != NULL);
- GGML_ASSERT(a->type == b->type);
- GGML_ASSERT(ggml_are_same_shape(a, b));
- GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
- return true;
- }
- // Copy the data of the tensor called `name` in `ctx` into `dst`; does nothing when dst is NULL.
- // The source tensor must exist and match dst in type, shape and contiguity (asserted via
- // are_same_layout). If dst has no name yet it is given `name`.
- void read_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
- if (!dst) {
- return;
- }
- struct ggml_tensor * src = ggml_get_tensor(ctx, name);
- GGML_ASSERT(are_same_layout(dst, src));
- memcpy(dst->data, src->data, ggml_nbytes(src));
- if (ggml_get_name(dst)[0] == '\0') {
- ggml_set_name(dst, name);
- }
- }
- // Restore optimizer state `opt` from checkpoint metadata in `fctx` and tensor data in `f_ggml_ctx`.
- // The common counters are read first, then the optimizer type and its type-specific scalars;
- // ggml_opt_init is called only after those are set, and the optimizer state tensors are then
- // copied from the file by name. Missing required keys or an unknown optimizer type terminate the
- // program via die_fmt/die.
- void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
- // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
- uint32_t file_version;
- GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
- // only optimizer state format version 0 is supported
- GGML_ASSERT(file_version == 0);
- GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
- GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
- GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
- uint64_t nx;
- GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
- opt->nx = (size_t) nx;
- // don't call ggml_opt_init until optimizer type and optimizer specific parameters are known
- std::string opt_type;
- GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
- if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
- opt->params.type = GGML_OPT_ADAM;
- GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
- GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
- GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
- GGML_ASSERT(opt->ctx != NULL);
- // NOTE(review): ggml_opt_init presumably (re)creates the adam state tensors in opt->ctx, which
- // read_tensor_by_name then fills; a NULL tensor (e.g. pf) is skipped by read_tensor_by_name
- ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
- read_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
- read_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
- read_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
- } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
- opt->params.type = GGML_OPT_LBFGS;
- GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
- GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
- GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
- GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
- GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
- GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
- GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
- GGML_ASSERT(opt->ctx != NULL);
- // NOTE(review): as above, ggml_opt_init presumably creates the lbfgs state tensors sized by
- // opt->nx and params.lbfgs.m before they are filled from the file
- ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
- read_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
- read_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
- read_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
- read_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
- read_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
- read_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
- read_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
- read_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
- read_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
- read_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
- } else {
- die("unknown optimizer type");
- }
- }
- // Serialize optimizer state `opt` into `fctx`: the common counters, the optimizer type and its
- // type-specific scalars as KV pairs, and the optimizer state tensors (renamed in place to their
- // LLM_TENSOR_OPTIMIZER_* names) as tensors. The past-loss tensor `pf` is written only if present.
- // Terminates via die() on an unknown optimizer type rather than writing a checkpoint that
- // load_opt_context_gguf could not read back.
- void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
- gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
- gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
- switch (opt->params.type) {
- case GGML_OPT_ADAM:
- {
- gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
- gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
- gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev);
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
- ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
- ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
- if (opt->adam.pf) {
- ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
- }
- gguf_add_tensor(fctx, opt->adam.m);
- gguf_add_tensor(fctx, opt->adam.v);
- if (opt->adam.pf) {
- gguf_add_tensor(fctx, opt->adam.pf);
- }
- } break;
- case GGML_OPT_LBFGS:
- {
- gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
- gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best);
- gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step);
- gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j);
- gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k);
- gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end);
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
- ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
- ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
- ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
- ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
- ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
- if (opt->lbfgs.pf) {
- ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
- }
- ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
- ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
- ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
- ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
- gguf_add_tensor(fctx, opt->lbfgs.x);
- gguf_add_tensor(fctx, opt->lbfgs.xp);
- gguf_add_tensor(fctx, opt->lbfgs.g);
- gguf_add_tensor(fctx, opt->lbfgs.gp);
- gguf_add_tensor(fctx, opt->lbfgs.d);
- if (opt->lbfgs.pf) {
- gguf_add_tensor(fctx, opt->lbfgs.pf);
- }
- gguf_add_tensor(fctx, opt->lbfgs.lmal);
- gguf_add_tensor(fctx, opt->lbfgs.lmys);
- gguf_add_tensor(fctx, opt->lbfgs.lms);
- gguf_add_tensor(fctx, opt->lbfgs.lmy);
- } break;
- default:
- {
- // without an optimizer type key the checkpoint could not be loaded again
- die("unknown optimizer type");
- } break;
- }
- }
- // Load llama hyper-parameters and weights into `model` from checkpoint metadata `fctx` and tensor
- // data `f_ggml_ctx`. Only the "llama" architecture with all-F32 tensors is accepted (asserted).
- // init_model(model) is called once the hparams are known; every weight is then copied from the
- // file tensor of the same name via read_tensor_by_name.
- void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model) {
- // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
- std::string arch;
- std::vector<char> keybuf;
- keybuf.resize(512);
- // expand an arch-templated key: the LLM_KV_* string is used as a printf format with `arch` as
- // argument; the returned pointer is only valid until the next kv() call
- auto kv = [&arch, &keybuf](const char * key) -> const char * {
- snprintf(keybuf.data(), keybuf.size(), key, arch.c_str());
- return keybuf.data();
- };
- std::vector<char> tn_buf;
- tn_buf.resize(GGML_MAX_NAME);
- // tensor name "<key>.weight"; shares tn_buf with tni, so valid until the next tn/tni call
- auto tn = [&tn_buf](const char * key) -> const char * {
- snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
- return tn_buf.data();
- };
- // per-layer tensor name: format `key` with block index `bid`, then append ".weight"
- auto tni = [&tn_buf](const char * key, int bid) -> const char * {
- snprintf(tn_buf.data(), tn_buf.size(), key, bid);
- std::string s = tn_buf.data();
- snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
- return tn_buf.data();
- };
- // arch must be read before any kv() call, since kv() substitutes it into the key names
- GGUF_GET_KEY(fctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE);
- GGML_ASSERT(arch == "llama");
- uint32_t ftype_u;
- GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE);
- GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32);
- // n_ctx was not saved in earlier checkpoint file versions, so we make it optional here
- GGUF_GET_KEY(fctx, model->hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
- GGUF_GET_KEY(fctx, model->hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
- GGUF_GET_KEY(fctx, model->hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
- GGUF_GET_KEY(fctx, model->hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
- GGUF_GET_KEY(fctx, model->hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
- // n_rot defaults to the per-head dimension unless the file specifies it
- model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head;
- GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
- float rope_freq_scale = 1.0f;
- // optional keys: when absent, f_norm_rms_eps and rope_freq_base keep the values hparams already hold
- GGUF_GET_KEY(fctx, model->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
- GGUF_GET_KEY(fctx, model->hparams.rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
- GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
- // the file stores the linear scale factor, hparams keep its reciprocal (save_llama_model_gguf
- // writes 1/rope_freq_scale); an absent key or a factor of 1 leaves hparams.rope_freq_scale unchanged
- if (rope_freq_scale != 1.0f) {
- model->hparams.rope_freq_scale = 1.0f / rope_freq_scale;
- }
- // allocate the model tensors for the loaded hparams, then fill them from the file
- init_model(model);
- read_tensor_by_name(model->tok_embeddings, f_ggml_ctx, tn(LLM_TENSOR_TOKEN_EMBD));
- read_tensor_by_name(model->norm, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT_NORM));
- read_tensor_by_name(model->output, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT));
- for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
- auto & layer = model->layers[i];
- read_tensor_by_name(layer.attention_norm, f_ggml_ctx, tni(LLM_TENSOR_ATTN_NORM, i));
- read_tensor_by_name(layer.wq, f_ggml_ctx, tni(LLM_TENSOR_ATTN_Q, i));
- read_tensor_by_name(layer.wk, f_ggml_ctx, tni(LLM_TENSOR_ATTN_K, i));
- read_tensor_by_name(layer.wv, f_ggml_ctx, tni(LLM_TENSOR_ATTN_V, i));
- read_tensor_by_name(layer.wo, f_ggml_ctx, tni(LLM_TENSOR_ATTN_OUT, i));
- read_tensor_by_name(layer.ffn_norm, f_ggml_ctx, tni(LLM_TENSOR_FFN_NORM, i));
- read_tensor_by_name(layer.w1, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
- read_tensor_by_name(layer.w2, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
- read_tensor_by_name(layer.w3, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i));
- }
- }
- // Add llama hyper-parameters, the tokenizer vocabulary and all weight tensors of `model` to `fctx`.
- // Architecture is fixed to "llama" and file type to all-F32. The vocabulary (token list, scores,
- // token types, tokenizer model name, BPE merges for "gpt2", special token ids) is copied from the
- // GGUF file `fn_vocab_model`. Dies if that file cannot be opened, lacks a required tokenizer array,
- // or has score/token-type arrays whose length differs from the token list.
- void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model) {
- const char * arch = "llama";
- enum llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
- std::vector<char> keybuf;
- keybuf.resize(512);
- // expand an arch-templated key: the LLM_KV_* string is used as a printf format with `arch` as
- // argument; the returned pointer is only valid until the next kv() call
- auto kv = [arch, &keybuf](const char * key) -> const char * {
- snprintf(keybuf.data(), keybuf.size(), key, arch);
- return keybuf.data();
- };
- // set arch
- gguf_set_val_str(fctx, LLM_KV_GENERAL_ARCHITECTURE, arch);
- gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype);
- // set hparams
- gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.n_ctx );
- gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd );
- gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff );
- gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head );
- gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT), model->hparams.n_layer );
- gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_rot );
- gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), model->hparams.f_norm_rms_eps );
- gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE), model->hparams.rope_freq_base ); // TODO load in llama.cpp
- // the file stores the linear scale factor, i.e. the reciprocal of hparams.rope_freq_scale
- gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR), 1.0f / model->hparams.rope_freq_scale );
- // set vocab by copying from vocab_model gguf file
- {
- struct gguf_init_params params = {
- /*.no_alloc = */ false,
- /*.ctx = */ NULL,
- };
- struct gguf_context * vctx = gguf_init_from_file(fn_vocab_model, params);
- if (vctx == NULL) {
- die_fmt("failed to load vocab model file: %s", fn_vocab_model);
- }
- const int token_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_LIST));
- if (token_idx == -1) {
- die("cannot find tokenizer vocab in model file");
- }
- const uint32_t n_vocab = gguf_get_arr_n(vctx, token_idx);
- const int score_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_SCORES));
- if (score_idx == -1) {
- die("cannot find tokenizer scores in model file");
- }
- // scores and token types are copied with n_vocab entries: reject shorter arrays instead of
- // reading past their end
- if ((uint32_t) gguf_get_arr_n(vctx, score_idx) != n_vocab) {
- die("tokenizer scores and token list differ in length");
- }
- const float * scores = (const float * ) gguf_get_arr_data(vctx, score_idx);
- const int toktype_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE));
- if (toktype_idx == -1) {
- die("cannot find token type list in GGUF file");
- }
- if ((uint32_t) gguf_get_arr_n(vctx, toktype_idx) != n_vocab) {
- die("token type list and token list differ in length");
- }
- const int * toktypes = (const int * ) gguf_get_arr_data(vctx, toktype_idx);
- std::string tokenizer_name;
- GGUF_GET_KEY(vctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
- gguf_set_val_str(fctx, kv(LLM_KV_TOKENIZER_MODEL), tokenizer_name.c_str());
- gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_SCORES), GGUF_TYPE_FLOAT32, scores, n_vocab);
- gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE), GGUF_TYPE_INT32, toktypes, n_vocab);
- // llama defaults; overridden per tokenizer below and by explicit ids in the vocab file
- int32_t special_bos_id = 1;
- int32_t special_eos_id = 2;
- int32_t special_unk_id = 0;
- int32_t special_sep_id = -1;
- int32_t special_pad_id = -1;
- if (tokenizer_name == "llama") {
- // default special tokens
- special_bos_id = 1;
- special_eos_id = 2;
- special_unk_id = 0;
- special_sep_id = -1;
- special_pad_id = -1;
- } else if (tokenizer_name == "gpt2") {
- // read and copy bpe merges
- const int merges_keyidx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_MERGES));
- if (merges_keyidx == -1) {
- die("cannot find tokenizer merges in model file");
- }
- const int n_merges = gguf_get_arr_n(vctx, merges_keyidx);
- std::vector<const char*> merges;
- merges.resize(n_merges);
- for (int i = 0; i < n_merges; i++) {
- merges[i] = gguf_get_arr_str(vctx, merges_keyidx, i);
- }
- gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_MERGES), merges.data(), n_merges);
- // default special tokens
- special_bos_id = 11;
- special_eos_id = 11;
- special_unk_id = -1;
- special_sep_id = -1;
- special_pad_id = -1;
- } else {
- fprintf(stderr, "%s: unknown tokenizer: '%s'\n", __func__, tokenizer_name.c_str());
- fprintf(stderr, "%s: using default tokenizer: 'llama'\n", __func__);
- }
- std::vector<const char*> tokens;
- tokens.resize(n_vocab);
- for (uint32_t i = 0; i < n_vocab; i++) {
- tokens[i] = gguf_get_arr_str(vctx, token_idx, i);
- }
- gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_LIST), tokens.data(), n_vocab);
- // explicit special token ids in the vocab file take precedence over the defaults above
- GGUF_GET_KEY(vctx, special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
- GGUF_GET_KEY(vctx, special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID));
- GGUF_GET_KEY(vctx, special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
- GGUF_GET_KEY(vctx, special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
- GGUF_GET_KEY(vctx, special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
- gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_BOS_ID), special_bos_id);
- gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_EOS_ID), special_eos_id);
- gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_UNK_ID), special_unk_id);
- gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_SEP_ID), special_sep_id);
- gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_PAD_ID), special_pad_id);
- gguf_free(vctx);
- }
- // add tensors
- gguf_add_tensor(fctx, model->tok_embeddings);
- gguf_add_tensor(fctx, model->norm);
- gguf_add_tensor(fctx, model->output);
- for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
- auto & layer = model->layers[i];
- gguf_add_tensor(fctx, layer.attention_norm);
- gguf_add_tensor(fctx, layer.wq);
- gguf_add_tensor(fctx, layer.wk);
- gguf_add_tensor(fctx, layer.wv);
- gguf_add_tensor(fctx, layer.wo);
- gguf_add_tensor(fctx, layer.ffn_norm);
- gguf_add_tensor(fctx, layer.w1);
- gguf_add_tensor(fctx, layer.w2);
- gguf_add_tensor(fctx, layer.w3);
- }
- }
- // Build a fresh GGUF context holding the model (hparams, vocab copied from fn_vocab_model, weights)
- // and write it to `filename`, metadata and tensor data alike.
- void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) {
- struct gguf_context * out_ctx = gguf_init_empty();
- save_llama_model_gguf(out_ctx, fn_vocab_model, model);
- gguf_write_to_file(out_ctx, filename, /*only_meta =*/ false);
- gguf_free(out_ctx);
- }
- // Restore a full training checkpoint from `fctx` (metadata) and `f_ggml_ctx` (tensor data):
- // model hparams and weights first, then the training counters, then the optimizer state.
- void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) {
- load_llama_model_gguf(fctx, f_ggml_ctx, model);
- uint32_t file_version;
- GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
- // only training state format version 0 is supported
- GGML_ASSERT(file_version == 0);
- GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
- GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
- GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
- load_opt_context_gguf(fctx, f_ggml_ctx, opt);
- }
- void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {
- save_llama_model_gguf(fctx, fn_vocab_model, model);
- gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 0);
- gguf_set_val_u32(fctx, LLM_KV_TRAINING_ITERATION_COUNT, model->train_its);
- gguf_set_val_u32(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, model->train_samples);
- gguf_set_val_u32(fctx, LLM_KV_TRAINING_TOKEN_COUNT, model->train_tokens);
- save_opt_context_gguf(fctx, opt);
- }
- bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct ggml_opt_context * opt) {
- struct ggml_context * f_ggml_ctx;
- struct gguf_init_params params;
- params.no_alloc = false;
- params.ctx = &f_ggml_ctx;
- struct gguf_context * fctx = gguf_init_from_file(filename, params);
- if (fctx == NULL) {
- return false;
- }
- load_checkpoint_gguf(fctx, f_ggml_ctx, model, opt);
- return true;
- }
- void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {
- struct gguf_context * fctx = gguf_init_empty();
- save_checkpoint_gguf(fctx, fn_vocab_model, model, opt);
- // write file
- const bool only_meta = false;
- gguf_write_to_file(fctx, filename, only_meta);
- gguf_free(fctx);
- }
// Cosine learning-rate decay factor.
//
// Returns a multiplier that falls from 1.0 at step 0 to `minimum` at step
// `decay_steps`, following half a cosine period. Steps beyond `decay_steps`
// are clamped, so the factor stays at `minimum`.
// A non-positive `decay_steps` means there is nothing left to decay, so
// `minimum` is returned (a zero length would otherwise produce 0/0 = NaN).
float cosine_decay(const int decay_steps, const float minimum, int step) {
    if (decay_steps <= 0) {
        return minimum;
    }
    if (step > decay_steps) {
        step = decay_steps;
    }
    const float cos_factor = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
    const float decay = (1 - minimum)*cos_factor + minimum;
    return decay;
}

// Cosine decay with optional warm restarts.
//
// With `enable_restart`, completed cycles are subtracted from `step`; after
// each restart the cycle length becomes `restart_step_mult` times the previous
// one (truncated to whole steps). The factor is then evaluated within the
// current cycle via cosine_decay. Without restarts this is plain cosine_decay.
float cosine_decay_restart(int decay_steps, const float minimum, int step, float restart_step_mult, bool enable_restart) {
    if (enable_restart) {
        // decay_steps > 0 guarantees progress: a zero/negative cycle length
        // (e.g. --cos-decay-steps 0, or a multiplier < 1 shrinking the cycle
        // to 0) previously made this loop spin forever.
        while (decay_steps > 0 && step > decay_steps) {
            step -= decay_steps;
            // multiply first, then truncate; casting only the multiplier
            // turned e.g. 1.1 into 1, so the cycle length never grew.
            // NOTE(review): a very large multiplier could overflow int here.
            decay_steps = (int) (restart_step_mult * decay_steps);
        }
    }
    return cosine_decay(decay_steps, minimum, step);
}
// All command-line configurable training options.
// Defaults are set by get_default_train_params(); the matching flags are
// listed by train_print_usage().
struct train_params {
    // file paths
    const char * fn_vocab_model;     // --vocab-model: model to load the vocab from
    const char * fn_train_data;      // --train-data: training text
    const char * fn_checkpoint_in;   // --checkpoint-in: checkpoint to resume from
    const char * fn_checkpoint_out;  // --checkpoint-out: checkpoint to write
    const char * fn_model_out;       // --model-out: final model file

    uint32_t seed;                   // -s/--seed: RNG seed, -1 selects a random seed

    // model shape (the shape options apply to newly created models)
    int n_ctx;                       // -c/--ctx: context size used during training
    int n_embd;                      // --embd
    int n_head;                      // --head
    int n_layer;                     // --layer
    int n_ff;                        // --ff: feed-forward size

    int n_threads;                   // -t/--threads
    int n_batch;                     // -b/--batch: parallel batch size
    int n_examples;                  // -n/--examples: number of examples to train

    float f_norm_rms_eps;            // --norm-rms-eps
    float rope_freq_base;            // --rope-freq-base
    float rope_freq_scale;           // --rope-freq-scale

    int  print_info_interval;        // print info every N examples
    bool samples_start_after_nl;     // samples only start right after newlines
    bool use_adam;                   // false selects LBFGS
    bool use_flash;                  // flash attention
    bool use_checkpointing;          // gradient checkpointing
    bool use_alloc;                  // use the ggml allocator

    // learning-rate schedule (Adam only)
    int   warmup;                    // linear warmup steps
    int   cos_decay_steps;           // length of the first cosine cycle
    float cos_decay_restart;         // cycle-length multiplier after a restart
    float cos_decay_min;             // minimum cosine decay factor
    bool  enable_restart;            // enable cosine restarts

    // convergence tests
    int   opt_past;                  // iterations tracked for the delta test (0 = off)
    float opt_delta;                 // maximum delta for the delta test
    int   opt_max_no_improvement;    // max iterations without improvement

    int lbfgs_n_iter;                // LBFGS iterations per batch

    // Adam / AdamW
    int   adam_n_iter;               // Adam iterations per batch
    float adam_alpha;                // learning rate
    float adam_min_alpha;            // lower bound of the scheduled learning rate
    float adam_decay;                // AdamW weight decay (> 0 enables AdamW)
    int   adam_decay_min_ndim;       // min tensor dims for weight decay
    float adam_beta1;
    float adam_beta2;
    float adam_gclip;                // gradient clipping (0 = off)
    float adam_eps_f;                // epsilon for the convergence test

    // memory budgets, in units of 1024^3 bytes
    int mem_model_gb;                // model and cache
    int mem_compute_gb;              // compute context
    int mem_compute0_gb;             // automatic allocator buffer
};
- struct train_params get_default_train_params() {
- struct train_params params;
- params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin";
- params.fn_train_data = "shakespeare.txt";
- params.fn_checkpoint_in = "checkpoint.bin";
- params.fn_checkpoint_out = "checkpoint.bin";
- params.fn_model_out = "ggml-checkpoint-f32.bin";
- params.seed = -1;
- params.n_ctx = 128;
- params.n_embd = 256;
- params.n_head = 8;
- params.n_layer = 16;
- params.n_ff = 768;
- params.n_threads = 6;
- params.n_batch = 8;
- params.n_examples = 1;
- params.f_norm_rms_eps = 1e-5;
- params.rope_freq_base = 10000.0f;
- params.rope_freq_scale = 1.0f;
- params.print_info_interval = 1;
- params.samples_start_after_nl = false;
- params.use_adam = true;
- params.use_flash = true;
- params.use_checkpointing = true;
- params.use_alloc = true;
- params.opt_past = 0;
- params.opt_delta = 1e-5f;
- params.opt_max_no_improvement = 0;
- // only adam
- params.warmup = 100;
- params.cos_decay_steps = 1000;
- params.cos_decay_restart = 1.1f;
- params.cos_decay_min = 0.1f;
- params.enable_restart = false;
- params.lbfgs_n_iter = 256;
- params.adam_n_iter = 256;
- params.adam_alpha = 1e-3f;
- params.adam_min_alpha = 0;
- params.adam_decay = 1e-1f;
- params.adam_decay_min_ndim = 2;
- params.adam_beta1 = 0.9f;
- params.adam_beta2 = 0.999f;
- params.adam_gclip = 1.0f;
- params.adam_eps_f = 0.0f;
- params.mem_model_gb = 2;
- params.mem_compute_gb = 24;
- params.mem_compute0_gb = 8;
- return params;
- }
- void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
- fprintf(stderr, "usage: %s [options]\n", argv[0]);
- fprintf(stderr, "\n");
- fprintf(stderr, "options:\n");
- fprintf(stderr, " -h, --help show this help message and exit\n");
- fprintf(stderr, " --vocab-model FNAME model path from which to load vocab (default '%s')\n", params->fn_vocab_model);
- fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
- fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
- fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
- fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out);
- fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
- fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
- fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd);
- fprintf(stderr, " --ff N Feedforward size used for new models. (default %d)\n", params->n_ff);
- fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head);
- fprintf(stderr, " --layer N Number of layers for new models (default %d)\n", params->n_layer);
- fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
- fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
- fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
- fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
- fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
- fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples);
- fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval);
- fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
- fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
- fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
- fprintf(stderr, " --no-flash Don't use flash attention \n");
- fprintf(stderr, " --use-flash Use flash attention (default)\n");
- fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
- fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
- fprintf(stderr, " --no-alloc Don't use allocator\n");
- fprintf(stderr, " --use-alloc Use allocator (default)\n");
- fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
- fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
- fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
- fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
- fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
- fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
- fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
- fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
- fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
- fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
- fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
- fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
- fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
- fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
- fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
- fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
- fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
- fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
- fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
- fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
- fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
- fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb);
- fprintf(stderr, "\n");
- }
- bool train_params_parse(int argc, char ** argv, struct train_params * params) {
- bool invalid_param = false;
- std::string arg;
- struct train_params default_params = get_default_train_params();
- const std::string arg_prefix = "--";
- for (int i = 1; i < argc; i++) {
- arg = argv[i];
- if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
- std::replace(arg.begin(), arg.end(), '_', '-');
- }
- if (arg == "--vocab-model") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->fn_vocab_model = argv[i];
- } else if (arg == "--train-data") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->fn_train_data = argv[i];
- } else if (arg == "--checkpoint-in") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->fn_checkpoint_in = argv[i];
- } else if (arg == "--checkpoint-out") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->fn_checkpoint_out = argv[i];
- } else if (arg == "--model-out") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->fn_model_out = argv[i];
- } else if (arg == "-s" || arg == "--seed") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->seed = std::stoi(argv[i]);
- } else if (arg == "-c" || arg == "--ctx") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_ctx = std::stoi(argv[i]);
- } else if (arg == "--embd") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_embd = std::stoi(argv[i]);
- } else if (arg == "--ff") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_ff = std::stoi(argv[i]);
- } else if (arg == "--head") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_head = std::stoi(argv[i]);
- } else if (arg == "--layer") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_layer = std::stoi(argv[i]);
- } else if (arg == "--norm-rms-eps") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->f_norm_rms_eps = std::stof(argv[i]);
- } else if (arg == "--rope-freq-base") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->rope_freq_base = std::stof(argv[i]);
- } else if (arg == "--rope-freq-scale") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->rope_freq_scale = std::stof(argv[i]);
- } else if (arg == "-t" || arg == "--threads") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_threads = std::stoi(argv[i]);
- } else if (arg == "-b" || arg == "--batch") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_batch = std::stoi(argv[i]);
- } else if (arg == "-n" || arg == "--examples") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_examples = std::stoi(argv[i]);
- } else if (arg == "--print-info-interval") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->print_info_interval = std::stoi(argv[i]);
- } else if (arg == "--samples-after-nl") {
- params->samples_start_after_nl = true;
- } else if (arg == "--use-lbfgs") {
- params->use_adam = false;
- } else if (arg == "--use-adam") {
- params->use_adam = true;
- } else if (arg == "--no-flash") {
- params->use_flash = false;
- } else if (arg == "--use-flash") {
- params->use_flash = true;
- } else if (arg == "--no-checkpointing") {
- params->use_checkpointing = false;
- } else if (arg == "--use-checkpointing") {
- params->use_checkpointing = true;
- } else if (arg == "--no-alloc") {
- params->use_alloc = false;
- } else if (arg == "--use-alloc") {
- params->use_alloc = true;
- } else if (arg == "--warmup") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->warmup = std::stoi(argv[i]);
- } else if (arg == "--cos-decay-steps") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->cos_decay_steps = std::stof(argv[i]);
- } else if (arg == "--cos-decay-restart") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->cos_decay_restart = std::stof(argv[i]);
- } else if (arg == "--cos-decay-min") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->cos_decay_min = std::stof(argv[i]);
- } else if (arg == "--enable-restart") {
- params->enable_restart = true;
- } else if (arg == "--disable-restart") {
- params->enable_restart = false;
- } else if (arg == "--opt-past") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->opt_past = std::stoi(argv[i]);
- } else if (arg == "--opt-delta") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->opt_delta = std::stof(argv[i]);
- } else if (arg == "--opt-max-no-improvement") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->opt_max_no_improvement = std::stoi(argv[i]);
- } else if (arg == "--adam-epsf") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_eps_f = std::stof(argv[i]);
- } else if (arg == "--adam-iter") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_n_iter = std::stoi(argv[i]);
- } else if (arg == "--adam-alpha") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_alpha = std::stof(argv[i]);
- } else if (arg == "--adam-min-alpha") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_min_alpha = std::stof(argv[i]);
- } else if (arg == "--adam-decay") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_decay = std::stof(argv[i]);
- } else if (arg == "--adam-decay-min-ndim") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_decay_min_ndim = std::stoi(argv[i]);
- } else if (arg == "--adam-beta1") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_beta1 = std::stof(argv[i]);
- } else if (arg == "--adam-beta2") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_beta2 = std::stof(argv[i]);
- } else if (arg == "--adam-gclip") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_gclip = std::stof(argv[i]);
- } else if (arg == "--lbfgs-iter") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->lbfgs_n_iter = std::stoi(argv[i]);
- } else if (arg == "--mem-model") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->mem_model_gb = std::stoi(argv[i]);
- } else if (arg == "--mem-compute") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->mem_compute_gb = std::stoi(argv[i]);
- } else if (arg == "--mem-compute0") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->mem_compute0_gb = std::stoi(argv[i]);
- } else if (arg == "-h" || arg == "--help") {
- train_print_usage(argc, argv, &default_params);
- exit(0);
- } else {
- fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
- train_print_usage(argc, argv, &default_params);
- exit(1);
- }
- }
- if (invalid_param) {
- fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
- train_print_usage(argc, argv, &default_params);
- exit(1);
- }
- return true;
- }
- struct opt_callback_data {
- struct train_params * params;
- struct ggml_opt_context * opt;
- struct llama_context * lctx;
- llama_token * tokens_data;
- size_t tokens_size;
- int * samples_data;
- size_t samples_size;
- int shuffle_countdown;
- struct ggml_tensor * tokens_input;
- struct ggml_tensor * target_logits;
- struct ggml_tensor * target_probs;
- };
- void opt_callback(void * vdata, float * sched) {
- struct opt_callback_data * data = (struct opt_callback_data *) vdata;
- struct train_params * params = data->params;
- struct ggml_opt_context * opt = data->opt;
- int n_batch = params->n_batch;
- *sched = (opt->iter < params->warmup)
- ? (float) opt->iter / (float) params->warmup
- : cosine_decay_restart(
- params->cos_decay_steps,
- params->cos_decay_min,
- opt->iter - params->warmup,
- params->cos_decay_restart,
- params->enable_restart);
- float min_sched = params->adam_min_alpha / params->adam_alpha;
- *sched = min_sched + *sched * (1.0f - min_sched);
- int impr_plot = std::isnan(opt->loss_after) ? 0 : -std::lround(1 + (opt->loss_before - opt->loss_after) * 10.0f);
- printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0);
- if (data->shuffle_countdown < n_batch) {
- printf("%s: reshuffle samples\n", __func__);
- shuffle_ints(data->samples_data, data->samples_data + data->samples_size);
- for (int i = 0; i < (int) data->samples_size; ++i) {
- GGML_ASSERT(data->samples_data[i]+params->n_ctx-1 < (int) data->tokens_size);
- }
- data->shuffle_countdown = data->samples_size;
- }
- get_example_targets_batch(
- data->lctx,
- data->samples_data,
- data->samples_size,
- data->tokens_data,
- data->tokens_size,
- opt->iter,
- data->tokens_input,
- data->target_logits,
- data->target_probs);
- data->shuffle_countdown -= n_batch;
- }
- int main(int argc, char ** argv) {
- struct train_params params = get_default_train_params();
- if (!train_params_parse(argc, argv, ¶ms)) {
- return 1;
- }
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
- printf("%s: seed: %u\n", __func__, params.seed);
- srand(params.seed);
- struct llama_context_params llama_params = llama_context_default_params();
- llama_params.vocab_only = true;
- struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
- struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
- printf("%s: tokenize training data\n", __func__);
- std::vector<llama_token> train_tokens;
- if (tokenize_file(lctx, params.fn_train_data, train_tokens) < 0) {
- fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, params.fn_train_data);
- }
- printf("%s: number of training tokens: %d\n", __func__, (int) train_tokens.size());
- struct my_llama_model model;
- model.hparams.n_vocab = llama_n_vocab(lctx);
- model.hparams.n_ctx = params.n_ctx;
- model.hparams.n_embd = params.n_embd;
- model.hparams.n_head = params.n_head;
- model.hparams.n_layer = params.n_layer;
- model.hparams.n_ff = params.n_ff;
- // llama.cpp requires n_rot to be exactly n_embd / n_head
- model.hparams.n_rot = model.hparams.n_embd / model.hparams.n_head;
- model.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
- model.hparams.rope_freq_base = params.rope_freq_base;
- model.hparams.rope_freq_scale = params.rope_freq_scale;
- print_params(&model.hparams);
- std::vector<size_t> token_noccurs;
- std::vector<bool> token_notavail;
- token_noccurs.resize(model.hparams.n_vocab, 0);
- token_notavail.resize(model.hparams.n_vocab, true);
- for (int i = 0; i < (int) train_tokens.size(); ++i) {
- ++token_noccurs[train_tokens[i]];
- token_notavail[train_tokens[i]] = false;
- }
- std::vector<float> token_freq;
- token_freq.resize(model.hparams.n_vocab, 0);
- int n_unique_tokens = 0;
- for (int i = 0; i < (int) token_noccurs.size(); ++i) {
- token_freq[i] = (float) token_noccurs[i] / (float) train_tokens.size();
- n_unique_tokens += (token_noccurs[i] > 0) ? 1 : 0;
- }
- printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
- struct ggml_init_params lcparams;
- lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb);
- lcparams.mem_buffer = NULL;
- lcparams.no_alloc = false;
- model.ctx = ggml_init(lcparams);
- int n_tokens = model.hparams.n_ctx;
- int n_vocab = model.hparams.n_vocab;
- int n_batch = params.n_batch;
- struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
- memset(opt, 0, sizeof(struct ggml_opt_context));
- struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
- struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
- opt_params_adam.print_forward_graph = false;
- opt_params_adam.print_backward_graph = false;
- opt_params_adam.n_threads = params.n_threads;
- opt_params_adam.past = params.opt_past;
- opt_params_adam.delta = params.opt_delta;
- opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
- opt_params_adam.adam.n_iter = params.adam_n_iter;
- opt_params_adam.adam.sched = 1.0f;
- opt_params_adam.adam.alpha = params.adam_alpha;
- opt_params_adam.adam.decay = params.adam_decay;
- opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim;
- opt_params_adam.adam.beta1 = params.adam_beta1;
- opt_params_adam.adam.beta2 = params.adam_beta2;
- opt_params_adam.adam.gclip = params.adam_gclip;
- opt_params_adam.adam.eps_f = params.adam_eps_f;
- opt_params_lbfgs.print_forward_graph = false;
- opt_params_lbfgs.print_backward_graph = false;
- opt_params_lbfgs.n_threads = params.n_threads;
- opt_params_adam.past = params.opt_past;
- opt_params_adam.delta = params.opt_delta;
- opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
- opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
- opt->ctx = model.ctx;
- opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
- printf("%s: init model\n", __func__);
- bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, opt);
- if (!existed) {
- init_model(&model);
- }
- set_param_model(&model);
- opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
- opt->iter = model.train_its;
- printf("%s: opt iter %d\n", __func__, opt->iter);
- bool from_scratch = !existed;
- if (from_scratch) {
- randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
- }
- printf("used_mem model: %zu bytes\n", ggml_used_mem(model.ctx));
- // ggml_print_tensor_objects(model.ctx);
- // TODO: use std::vector<uint8_t> intead of "new"
- size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
- uint8_t * compute_addr = new uint8_t[compute_size];
- size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
- uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
- ggml_allocr * alloc = NULL;
- if (params.use_alloc) {
- static const size_t tensor_alignment = 32;
- alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment);
- }
- GGML_ASSERT(n_tokens < (int) train_tokens.size());
- std::vector<int> train_samples;
- train_samples.push_back(0);
- for (int i = 1; i < (int) train_tokens.size() - n_tokens; ++i) {
- if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl(lctx))) {
- train_samples.push_back(i);
- }
- }
- shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
- for (int i = 0; i < (int) train_samples.size(); ++i) {
- GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
- }
- printf("%s: begin training\n", __func__);
- struct opt_callback_data opt_cb_data;
- opt_cb_data.params = ¶ms;
- opt_cb_data.opt = opt;
- opt_cb_data.lctx = lctx;
- opt_cb_data.tokens_data = train_tokens.data();
- opt_cb_data.tokens_size = train_tokens.size();
- opt_cb_data.samples_data = train_samples.data();
- opt_cb_data.samples_size = train_samples.size();
- opt_cb_data.shuffle_countdown = train_samples.size();
- opt_cb_data.tokens_input = NULL;
- opt_cb_data.target_logits = NULL;
- opt_cb_data.target_probs = NULL;
- int64_t t0 = ggml_time_ms();
- for (int ex = 0; ex < params.n_examples; ++ex) {
- if (ex*n_batch >= (int) train_samples.size()) {
- shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
- for (int i = 0; i < (int) train_samples.size(); ++i) {
- GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
- }
- }
- struct ggml_init_params cparams = {
- compute_size, // mem_size
- compute_addr, // mem_buffer
- false, // no_alloc
- };
- struct ggml_context * ctx0 = ggml_init(cparams);
- ggml_set_no_alloc(ctx0, false);
- // don't use alloc for input tensors, so we can safely fill them with data
- //struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
- //struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
- struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
- struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
- struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
- ggml_set_no_alloc(ctx0, (alloc != NULL));
- if (alloc) {
- ggml_allocr_reset(alloc);
- }
- opt_cb_data.tokens_input = tokens_input;
- opt_cb_data.target_logits = target_logits;
- opt_cb_data.target_probs = target_probs;
- int n_past = 0;
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- struct ggml_cgraph * gb = ggml_new_graph(ctx0);
- struct ggml_cgraph * gb_tmp = params.use_checkpointing
- ? ggml_new_graph(ctx0)
- : NULL;
- GGML_ASSERT(n_past == 0);
- struct ggml_tensor * loss = NULL;
- struct ggml_tensor * logits = NULL;
- loss = llama_build_train_graphs(
- &model, alloc, ctx0,
- gf, gb, gb_tmp,
- &logits, tokens_input, target_probs,
- n_tokens, n_batch,
- params.use_flash,
- params.use_checkpointing
- );
- size_t used_mem_before_opt = ggml_used_mem(ctx0);
- opt->params.adam.sched = (opt->iter < params.warmup)
- ? (float) opt->iter / (float) params.warmup
- : cosine_decay_restart(
- params.cos_decay_steps,
- params.cos_decay_min,
- opt->iter - params.warmup,
- params.cos_decay_restart,
- params.enable_restart);
- float min_sched = params.adam_min_alpha / params.adam_alpha;
- opt->params.adam.sched = min_sched + opt->params.adam.sched * (1.0f - min_sched);
- printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
- ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &opt_callback, (void *) &opt_cb_data);
- size_t used_mem_after_opt = ggml_used_mem(ctx0);
- int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter;
- model.train_its = opt->iter;
- model.train_samples += n_batch * n_iter;
- model.train_tokens += n_batch * n_tokens * n_iter;
- if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
- printf("Example %d, opt iter %d\n", ex, opt->iter);
- printf("error_before_opt: %.6f\n", opt->loss_before);
- printf("error_after_opt: %.6f\n", opt->loss_after);
- printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
- printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
- }
- ggml_free(ctx0);
- }
- int64_t t1 = ggml_time_ms();
- int64_t d = t1-t0;
- double dd = (double) d * 1e-3;
- printf("%s: total training time=%f seconds\n", __func__, dd);
- if (params.n_examples > 0) {
- save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt);
- }
- if (strlen(params.fn_model_out) > 0) {
- save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model);
- }
- if (alloc) {
- ggml_allocr_free(alloc);
- }
- delete[] compute_addr;
- delete[] compute_buf_0;
- ggml_free(model.ctx);
- llama_free(lctx);
- llama_free_model(lmodel);
- return 0;
- }
|