1
0

train.cpp 65 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511
661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515
  1. #include "train.h"
  2. #include "common.h"
  3. #include <algorithm>
  4. #include <random>
  5. #include <sstream>
  6. #include <functional>
  7. #include <cstring>
  8. struct random_normal_distribution {
  9. std::mt19937 gen;
  10. std::normal_distribution<float> rd;
  11. float min;
  12. float max;
  13. };
  14. struct random_uniform_distribution {
  15. std::mt19937 gen;
  16. std::uniform_real_distribution<float> rd;
  17. };
  18. struct train_state * init_train_state() {
  19. struct train_state * state = new struct train_state;
  20. state->train_its = 0;
  21. state->train_samples = 0;
  22. state->train_tokens = 0;
  23. state->train_epochs = 0;
  24. state->shuffle_samples_hash = 0;
  25. state->shuffle_sample_count = 0;
  26. state->shuffle_next_sample = 0;
  27. state->shuffle_rng_state_current = "";
  28. state->shuffle_rng_state_next = "";
  29. state->opt = new struct ggml_opt_context;
  30. state->opt->ctx = NULL;
  31. state->opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
  32. state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
  33. state->opt->loss_after = 0.0f;
  34. return state;
  35. }
  36. void free_train_state(struct train_state * state) {
  37. delete state->opt;
  38. delete state;
  39. }
  40. struct random_normal_distribution * init_random_normal_distribution(
  41. int seed, float mean, float std, float min, float max
  42. ) {
  43. struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution));
  44. rnd->gen = std::mt19937(seed);
  45. rnd->rd = std::normal_distribution<float>{mean, std};
  46. rnd->min = min;
  47. rnd->max = max;
  48. return rnd;
  49. }
  50. struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max) {
  51. struct random_uniform_distribution * rnd = (struct random_uniform_distribution *) malloc(sizeof(struct random_uniform_distribution));
  52. rnd->gen = std::mt19937(seed);
  53. rnd->rd = std::uniform_real_distribution<float>{min, max};
  54. return rnd;
  55. }
  56. void free_random_normal_distribution (struct random_normal_distribution * rnd) {
  57. free(rnd);
  58. }
  59. void free_random_uniform_distribution(struct random_uniform_distribution * rnd) {
  60. free(rnd);
  61. }
  62. struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
  63. float scale = 1.0f; // xavier
  64. switch (ggml_n_dims(tensor)) {
  65. case 1:
  66. scale /= sqrtf((float) tensor->ne[0]);
  67. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  68. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
  69. *dst = scale * frand_normal(rnd);
  70. }
  71. break;
  72. case 2:
  73. scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
  74. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  75. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  76. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
  77. *dst = scale * frand_normal(rnd);
  78. }
  79. }
  80. break;
  81. case 3:
  82. scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
  83. for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
  84. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  85. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  86. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
  87. *dst = scale * frand_normal(rnd);
  88. }
  89. }
  90. }
  91. break;
  92. case 4:
  93. scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
  94. for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
  95. for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
  96. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  97. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  98. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
  99. *dst = scale * frand_normal(rnd);
  100. }
  101. }
  102. }
  103. }
  104. break;
  105. default:
  106. die("Unsupported tensor->n_dims");
  107. };
  108. return tensor;
  109. }
  110. struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
  111. switch (ggml_n_dims(tensor)) {
  112. case 1:
  113. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  114. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
  115. *dst = frand_uniform(rnd);
  116. }
  117. break;
  118. case 2:
  119. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  120. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  121. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
  122. *dst = frand_uniform(rnd);
  123. }
  124. }
  125. break;
  126. case 3:
  127. for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
  128. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  129. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  130. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
  131. *dst = frand_uniform(rnd);
  132. }
  133. }
  134. }
  135. break;
  136. case 4:
  137. for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
  138. for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
  139. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  140. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  141. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
  142. *dst = frand_uniform(rnd);
  143. }
  144. }
  145. }
  146. }
  147. break;
  148. default:
  149. die("Unsupported tensor->n_dims");
  150. };
  151. return tensor;
  152. }
  153. float frand() {
  154. return (float)rand()/((float)(RAND_MAX) + 1.0f);
  155. }
  156. float frand_normal(struct random_normal_distribution * rnd) {
  157. return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
  158. }
  159. float frand_uniform(struct random_uniform_distribution * rnd) {
  160. return rnd->rd(rnd->gen);
  161. }
  162. int clamp(const int v, const int min, const int max) {
  163. return ((v < min) ? (min) : (v > max) ? (max) : v);
  164. }
  165. float fclamp(const float v, const float min, const float max) {
  166. return ((v < min) ? (min) : (v > max) ? (max) : v);
  167. }
  168. void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
  169. GGML_ASSERT(tensor->ne[0] == ne0);
  170. GGML_ASSERT(tensor->ne[1] == 1);
  171. GGML_ASSERT(tensor->ne[2] == 1);
  172. GGML_ASSERT(tensor->ne[3] == 1);
  173. }
  174. void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
  175. GGML_ASSERT(tensor->ne[0] == ne0);
  176. GGML_ASSERT(tensor->ne[1] == ne1);
  177. GGML_ASSERT(tensor->ne[2] == 1);
  178. GGML_ASSERT(tensor->ne[3] == 1);
  179. }
  180. void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
  181. GGML_ASSERT(tensor->ne[0] == ne0);
  182. GGML_ASSERT(tensor->ne[1] == ne1);
  183. GGML_ASSERT(tensor->ne[2] == ne2);
  184. GGML_ASSERT(tensor->ne[3] == 1);
  185. }
  186. void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
  187. GGML_ASSERT(tensor->ne[0] == ne0);
  188. GGML_ASSERT(tensor->ne[1] == ne1);
  189. GGML_ASSERT(tensor->ne[2] == ne2);
  190. GGML_ASSERT(tensor->ne[3] == ne3);
  191. }
  192. int64_t get_example_targets_batch(
  193. struct llama_context * lctx,
  194. struct ggml_tensor * tokens_input,
  195. struct ggml_tensor * target_probs,
  196. int64_t example_id,
  197. const size_t * samples_offs,
  198. const size_t * samples_begin,
  199. const size_t * samples_size,
  200. size_t samples_count,
  201. const llama_token * train_data,
  202. size_t n_train_data,
  203. bool separate_with_eos,
  204. bool separate_with_bos,
  205. bool fill_with_next_samples,
  206. bool sample_random_offsets
  207. ) {
  208. GGML_ASSERT(samples_count > 0);
  209. GGML_ASSERT(ggml_is_matrix(tokens_input));
  210. GGML_ASSERT(ggml_is_3d(target_probs));
  211. int64_t n_vocab = target_probs->ne[0];
  212. int64_t n_tokens = tokens_input->ne[0];
  213. int64_t n_batch = tokens_input->ne[1];
  214. GGML_ASSERT(n_vocab == target_probs->ne[0]);
  215. GGML_ASSERT(n_tokens == target_probs->ne[1]);
  216. GGML_ASSERT(n_batch == target_probs->ne[2]);
  217. int64_t used_samples = 0;
  218. ggml_set_f32(target_probs, 0.0f);
  219. llama_token bos = llama_token_bos(llama_get_model(lctx));
  220. llama_token eos = llama_token_eos(llama_get_model(lctx));
  221. // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
  222. for (int k=0; k<n_batch; ++k) {
  223. // printf("%s: batch %d\n", __func__, k);
  224. size_t sample_idx = (example_id + used_samples) % samples_count;
  225. size_t sample_offs = sample_random_offsets ? samples_offs[sample_idx] : 0;
  226. size_t sample_begin = samples_begin[sample_idx];
  227. size_t sample_size = samples_size[sample_idx];
  228. ++used_samples;
  229. // printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
  230. GGML_ASSERT(sample_begin+sample_size-1 < n_train_data);
  231. ggml_set_i32_nd(tokens_input, 0, k, 0, 0, bos);
  232. bool sample_separation_eos = !separate_with_eos;
  233. bool sample_separation_bos = !separate_with_bos;
  234. for (int64_t i=0; i<n_tokens; ++i) {
  235. llama_token token = eos;
  236. if (sample_offs >= sample_size && fill_with_next_samples) {
  237. if (!sample_separation_eos) {
  238. // insert eos token to separate samples
  239. sample_separation_eos = true;
  240. } else if (!sample_separation_bos) {
  241. // insert bos token to separate samples
  242. sample_separation_bos = true;
  243. token = bos;
  244. } else {
  245. // sample separation is done, continue with next sample
  246. sample_separation_eos = !separate_with_eos;
  247. sample_separation_bos = !separate_with_bos;
  248. sample_offs = 0;
  249. sample_idx = (example_id + used_samples) % samples_count;
  250. sample_begin = samples_begin[sample_idx];
  251. sample_size = samples_size[sample_idx];
  252. ++used_samples;
  253. }
  254. }
  255. // note: no else-if here
  256. if (sample_offs < sample_size) {
  257. token = clamp(train_data[sample_begin+sample_offs], 0, (llama_token) (n_vocab - 1));
  258. ++sample_offs;
  259. }
  260. ggml_set_f32_nd(target_probs, token, (int) i, (int) k, 0, +1.0f);
  261. if (i+1<n_tokens) {
  262. ggml_set_i32_nd(tokens_input, (int) (i + 1), (int) k, 0, 0, token);
  263. }
  264. }
  265. }
  266. return used_samples;
  267. }
  268. void mt19937_set_state(std::mt19937& rng, const std::string& rng_state) {
  269. std::stringstream s_rng_state;
  270. s_rng_state.imbue(std::locale::classic());
  271. s_rng_state.exceptions(std::stringstream::failbit);
  272. s_rng_state.str(rng_state);
  273. s_rng_state >> rng;
  274. }
  275. std::string mt19937_get_state(const std::mt19937& rng) {
  276. std::stringstream s_rng_state;
  277. s_rng_state.imbue(std::locale::classic());
  278. s_rng_state << rng;
  279. return s_rng_state.str();
  280. }
  281. std::string mt19937_seed_to_state(unsigned seed) {
  282. std::mt19937 rng(seed);
  283. return mt19937_get_state(rng);
  284. }
  285. std::string shuffle_samples(
  286. const std::string & rng_state,
  287. size_t * shuffled_offs,
  288. size_t * shuffled_begins,
  289. size_t * shuffled_sizes,
  290. const size_t * begins,
  291. const size_t * sizes,
  292. size_t count) {
  293. if (count == 0) return rng_state;
  294. std::mt19937 rng;
  295. mt19937_set_state(rng, rng_state);
  296. // sort indices by random value for each index
  297. std::vector<size_t> idcs;
  298. {
  299. std::vector<unsigned> rnd;
  300. idcs.resize(count);
  301. rnd.resize(count);
  302. for (unsigned i=0; i<count; ++i) {
  303. idcs[i] = i;
  304. rnd[i] = rng();
  305. }
  306. std::sort(idcs.begin(), idcs.end(), [&rnd](size_t a, size_t b){
  307. // stable sort for reproducibility
  308. return (rnd[a] == rnd[b]) ? (a < b) : (rnd[a] < rnd[b]);
  309. });
  310. }
  311. // create random offsets
  312. for (unsigned i=0; i<count; ++i) {
  313. shuffled_offs[i] = (size_t) ((sizes[idcs[i]] - 1) * ((double) rng() / (double) (rng.max()-1)));
  314. }
  315. // reorder begins and sizes by sorted indices
  316. for (unsigned i=0; i<count; ++i) {
  317. shuffled_begins[i] = begins[idcs[i]];
  318. }
  319. for (unsigned i=0; i<count; ++i) {
  320. shuffled_sizes[i] = sizes[idcs[i]];
  321. }
  322. return mt19937_get_state(rng);
  323. }
  324. size_t hash_combine(size_t h1, size_t h2) {
  325. return h1 ^ (h2 << 1);
  326. }
  327. size_t compute_samples_hash(const char* fn, const size_t* samples_begin, const size_t* samples_size, size_t sample_count) {
  328. std::hash<std::string> h_string;
  329. std::hash<unsigned long long> h_ull;
  330. size_t h = h_string(std::string(fn));
  331. h = hash_combine(h, h_ull((unsigned long long) sample_count));
  332. for (size_t i=0; i< sample_count; ++i) {
  333. h = hash_combine(h, h_ull((unsigned long long) samples_begin[i]));
  334. h = hash_combine(h, h_ull((unsigned long long) samples_size[i]));
  335. }
  336. return h;
  337. }
  338. std::string replace_str(const char * s, const char * needle, const char * replacement) {
  339. std::string str = s;
  340. size_t pos = str.find(needle);
  341. if (pos != std::string::npos) {
  342. str.replace(pos, strlen(needle), replacement);
  343. }
  344. return str;
  345. }
  346. void print_duration(double fmillis) {
  347. if (fmillis < 1000.0f) {
  348. printf("%.1fms", (float) fmillis);
  349. return;
  350. }
  351. const int64_t one_sec = 1000;
  352. const int64_t one_min = one_sec * 60;
  353. const int64_t one_hour = one_min * 60;
  354. const int64_t one_day = one_hour * 24;
  355. int64_t millis = (int64_t) fmillis;
  356. int64_t days = millis/one_day;
  357. int64_t hours = (millis - days*one_day)/one_hour;
  358. int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min;
  359. int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec;
  360. // to print int64_t either cast to (long long int) or use macro PRId64 from <inttypes.h>
  361. if (days > 0) {
  362. printf("%lldd ", (long long int) days);
  363. }
  364. printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds);
  365. }
  366. float cosine_decay(int64_t step, int64_t decay_steps, float minimum) {
  367. if (step > decay_steps) {
  368. step = decay_steps;
  369. }
  370. const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
  371. const float decay = (1 - minimum)*cosine_decay + minimum;
  372. return decay;
  373. }
  374. float cosine_decay_restart(int64_t step, int64_t decay_steps, float minimum, float restart_step_mult) {
  375. while (step > decay_steps) {
  376. step -= decay_steps;
  377. decay_steps = (int64_t) (restart_step_mult * decay_steps);
  378. }
  379. return cosine_decay(step, decay_steps, minimum);
  380. }
  381. float learning_schedule(
  382. int64_t step,
  383. int64_t warmup_steps,
  384. int64_t cos_decay_steps,
  385. float learning_rate,
  386. float overall_minimum,
  387. float cos_decay_minimum,
  388. float cos_decay_restart_step_mult,
  389. bool enable_restart) {
  390. float result =
  391. (step < warmup_steps)
  392. ? (float) step / (float) warmup_steps
  393. : enable_restart
  394. ? cosine_decay_restart(
  395. step - warmup_steps,
  396. cos_decay_steps,
  397. cos_decay_minimum,
  398. cos_decay_restart_step_mult)
  399. : cosine_decay(
  400. step,
  401. cos_decay_steps,
  402. cos_decay_minimum);
  403. float min = overall_minimum / learning_rate;
  404. result = min + result * (1.0f - min);
  405. return result;
  406. }
  407. static bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
  408. GGML_ASSERT(a != NULL);
  409. GGML_ASSERT(b != NULL);
  410. GGML_ASSERT(a->type == b->type);
  411. GGML_ASSERT(ggml_are_same_shape(a, b));
  412. GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
  413. return true;
  414. }
  415. void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
  416. if (dst == NULL) {
  417. return;
  418. }
  419. struct ggml_tensor * t = ggml_get_tensor(ctx, name);
  420. GGML_ASSERT(are_same_layout(dst, t));
  421. memcpy(dst->data, t->data, ggml_nbytes(t));
  422. if (strlen(ggml_get_name(dst)) == 0) {
  423. ggml_set_name(dst, name);
  424. }
  425. }
  426. // gguf constants
  427. static const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
  428. static const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam";
  429. static const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";
  430. static const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version";
  431. static const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count";
  432. static const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count";
  433. static const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count";
  434. static const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized";
  435. static const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss";
  436. static const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss";
  437. static const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count";
  438. static const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
  439. static const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss";
  440. static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step";
  441. static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j";
  442. static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k";
  443. static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end";
  444. static const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";
  445. static const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments";
  446. static const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments";
  447. static const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values";
  448. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters";
  449. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
  450. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients";
  451. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients";
  452. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction";
  453. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values";
  454. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha";
  455. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys";
  456. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
  457. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
  458. static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
  459. static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
  460. static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
  461. static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
  462. static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count";
  463. static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
  464. static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state";
  465. static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
  466. static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample";
  467. #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
  468. { \
  469. const std::string skey(key); \
  470. const int kid = gguf_find_key(ctx, skey.c_str()); \
  471. if (kid >= 0) { \
  472. enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
  473. if (ktype != (type)) { \
  474. die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
  475. } \
  476. (dst) = func(ctx, kid); \
  477. } else if (req) { \
  478. die_fmt("key not found in model: %s", skey.c_str()); \
  479. } \
  480. }
  481. void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
  482. // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
  483. uint32_t file_version;
  484. GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
  485. GGML_ASSERT(file_version == 0);
  486. GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
  487. GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
  488. GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
  489. uint64_t nx;
  490. GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
  491. opt->nx = (size_t) nx;
  492. // don't call ggml_opt_init until optimizer type and optimizer specific parameters are know
  493. std::string opt_type;
  494. GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
  495. if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
  496. opt->params.type = GGML_OPT_TYPE_ADAM;
  497. GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
  498. GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
  499. GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
  500. ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
  501. copy_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
  502. copy_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
  503. copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
  504. } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
  505. opt->params.type = GGML_OPT_TYPE_LBFGS;
  506. GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
  507. GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
  508. GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
  509. GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
  510. GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
  511. GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
  512. GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
  513. ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
  514. copy_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
  515. copy_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
  516. copy_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
  517. copy_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
  518. copy_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
  519. copy_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
  520. copy_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
  521. copy_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
  522. copy_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
  523. copy_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
  524. } else {
  525. die("unknown optimizer type\n");
  526. }
  527. }
  528. void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
  529. gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
  530. gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
  531. gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
  532. gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
  533. gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
  534. switch (opt->params.type) {
  535. case GGML_OPT_TYPE_ADAM:
  536. {
  537. gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
  538. gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
  539. gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev);
  540. gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
  541. ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
  542. ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
  543. if (opt->adam.pf) {
  544. ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
  545. }
  546. gguf_add_tensor(fctx, opt->adam.m);
  547. gguf_add_tensor(fctx, opt->adam.v);
  548. if (opt->adam.pf) {
  549. gguf_add_tensor(fctx, opt->adam.pf);
  550. }
  551. } break;
  552. case GGML_OPT_TYPE_LBFGS:
  553. {
  554. gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
  555. gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
  556. gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best);
  557. gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step);
  558. gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j);
  559. gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k);
  560. gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end);
  561. gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
  562. ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
  563. ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
  564. ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
  565. ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
  566. ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
  567. if (opt->lbfgs.pf) {
  568. ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
  569. }
  570. ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
  571. ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
  572. ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
  573. ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
  574. gguf_add_tensor(fctx, opt->lbfgs.x);
  575. gguf_add_tensor(fctx, opt->lbfgs.xp);
  576. gguf_add_tensor(fctx, opt->lbfgs.g);
  577. gguf_add_tensor(fctx, opt->lbfgs.gp);
  578. gguf_add_tensor(fctx, opt->lbfgs.d);
  579. if (opt->lbfgs.pf) {
  580. gguf_add_tensor(fctx, opt->lbfgs.pf);
  581. }
  582. gguf_add_tensor(fctx, opt->lbfgs.lmal);
  583. gguf_add_tensor(fctx, opt->lbfgs.lmys);
  584. gguf_add_tensor(fctx, opt->lbfgs.lms);
  585. gguf_add_tensor(fctx, opt->lbfgs.lmy);
  586. } break;
  587. }
  588. }
  589. bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) {
  590. if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) < 0) {
  591. return false;
  592. }
  593. uint32_t file_version;
  594. GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
  595. GGML_ASSERT(file_version <= 1);
  596. if (file_version == 0) {
  597. GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
  598. GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
  599. GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
  600. } else if (file_version == 1) {
  601. GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT);
  602. GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT);
  603. GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT);
  604. GGUF_GET_KEY(fctx, train->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT);
  605. GGUF_GET_KEY(fctx, train->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH);
  606. GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE);
  607. GGUF_GET_KEY(fctx, train->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT);
  608. GGUF_GET_KEY(fctx, train->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE);
  609. }
  610. load_opt_context_gguf(fctx, f_ggml_ctx, train->opt);
  611. return true;
  612. }
  613. void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
  614. gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
  615. gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
  616. gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, train->train_samples);
  617. gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, train->train_tokens);
  618. gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, train->train_epochs);
  619. gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash);
  620. gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, train->shuffle_rng_state_current.c_str());
  621. gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count);
  622. gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) train->shuffle_next_sample);
  623. save_opt_context_gguf(fctx, train->opt);
  624. }
  625. struct llama_file {
  626. // use FILE * so we don't have to re-open the file to mmap
  627. FILE * fp;
  628. size_t size;
  629. llama_file(const char * fname, const char * mode) {
  630. fp = std::fopen(fname, mode);
  631. if (fp == NULL) {
  632. size = 0;
  633. } else {
  634. seek(0, SEEK_END);
  635. size = tell();
  636. seek(0, SEEK_SET);
  637. }
  638. }
  639. size_t tell() const {
  640. #ifdef _WIN32
  641. __int64 ret = _ftelli64(fp);
  642. #else
  643. long ret = std::ftell(fp);
  644. #endif
  645. GGML_ASSERT(ret != -1); // this really shouldn't fail
  646. return (size_t) ret;
  647. }
  648. void seek(size_t offset, int whence) {
  649. #ifdef _WIN32
  650. int ret = _fseeki64(fp, (__int64) offset, whence);
  651. #else
  652. int ret = std::fseek(fp, (long) offset, whence);
  653. #endif
  654. GGML_ASSERT(ret == 0); // same
  655. }
  656. void read_raw(void * ptr, size_t size) {
  657. if (size == 0) {
  658. return;
  659. }
  660. errno = 0;
  661. std::size_t ret = std::fread(ptr, size, 1, fp);
  662. if (ferror(fp)) {
  663. die_fmt("read error: %s", strerror(errno));
  664. }
  665. if (ret != 1) {
  666. die("unexpectedly reached end of file");
  667. }
  668. }
  669. std::uint32_t read_u32() {
  670. std::uint32_t ret;
  671. read_raw(&ret, sizeof(ret));
  672. return ret;
  673. }
  674. std::string read_string(std::uint32_t len) {
  675. std::vector<char> chars(len);
  676. read_raw(chars.data(), len);
  677. return std::string(chars.data(), len);
  678. }
  679. void write_raw(const void * ptr, size_t size) {
  680. if (size == 0) {
  681. return;
  682. }
  683. errno = 0;
  684. size_t ret = std::fwrite(ptr, size, 1, fp);
  685. if (ret != 1) {
  686. die_fmt("write error: %s", strerror(errno));
  687. }
  688. }
  689. void write_u32(std::uint32_t val) {
  690. write_raw(&val, sizeof(val));
  691. }
  692. ~llama_file() {
  693. if (fp) {
  694. std::fclose(fp);
  695. }
  696. }
  697. };
  698. static size_t utf8_len(char src) {
  699. const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
  700. uint8_t highbits = static_cast<uint8_t>(src) >> 4;
  701. return lookup[highbits];
  702. }
  703. // mark each byte with its utf8 unit number.
  704. // returns the number of utf8 characters.
  705. // e.g. when bytes == '\x61\xD0\xB0\x62',
  706. // then utf8_units will become [0,0,1,0]
  707. // utf8_nunits will become [1,2,2,1] and 3 is returned.
  708. // bytes where utf8_units is zero, are the begin of an utf8 character.
  709. static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) {
  710. size_t offs = 0;
  711. size_t count_utf8 = 0;
  712. while(offs < count) {
  713. int len = (int) utf8_len(bytes[offs]);
  714. for (int i=0; i<len; ++i) {
  715. utf8_units[offs+i] = i;
  716. utf8_nunits[offs+i] = len;
  717. }
  718. offs += len;
  719. ++count_utf8;
  720. }
  721. return count_utf8;
  722. }
  723. size_t tokenize_file(
  724. struct llama_context * lctx,
  725. const char * filename,
  726. const std::string & sample_start,
  727. bool include_sample_start,
  728. bool overlapping_samples,
  729. unsigned context_length,
  730. std::vector<llama_token> & out_tokens,
  731. std::vector<size_t> & out_samples_begin,
  732. std::vector<size_t> & out_samples_size) {
  733. struct llama_file f(filename, "rb");
  734. if (f.size == 0) {
  735. out_tokens.clear();
  736. out_samples_begin.clear();
  737. out_samples_size.clear();
  738. printf("%s: warning: empty or not existing training data file '%s'\n",
  739. __func__, filename);
  740. return out_tokens.size();
  741. }
  742. // account for possible leading whitespace that will be added by tokenizer
  743. // e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
  744. const int n_max_tokens_overhead = 1;
  745. std::vector<char> buf;
  746. buf.resize(f.size);
  747. f.read_raw(buf.data(), f.size);
  748. std::vector<int> utf8_units;
  749. std::vector<int> utf8_nunits;
  750. utf8_units.resize(buf.size());
  751. utf8_nunits.resize(buf.size());
  752. mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
  753. if (sample_start.size() == 0) {
  754. // tokenize all data at once
  755. out_tokens.resize(buf.size() + n_max_tokens_overhead);
  756. int n_tokens = llama_tokenize(
  757. llama_get_model(lctx),
  758. buf.data(),
  759. (int) buf.size(),
  760. out_tokens.data(),
  761. (int) out_tokens.size(),
  762. false, false);
  763. if (n_tokens < 0) {
  764. out_tokens.resize(-n_tokens);
  765. n_tokens = llama_tokenize(
  766. llama_get_model(lctx),
  767. buf.data(),
  768. (int) buf.size(),
  769. out_tokens.data(),
  770. (int) out_tokens.size(),
  771. false, false);
  772. }
  773. if (n_tokens >= 0) {
  774. out_tokens.resize(n_tokens);
  775. }
  776. // generate sample starts at all token positions
  777. out_samples_begin.clear();
  778. out_samples_begin.push_back(0);
  779. out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size()));
  780. size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0;
  781. for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) {
  782. out_samples_begin.push_back(sample_begin);
  783. out_samples_size.push_back(context_length);
  784. }
  785. } else {
  786. // split data into samples and tokenize each sample
  787. std::string data_str(buf.data(), buf.size());
  788. out_samples_begin.clear();
  789. out_samples_size.clear();
  790. out_tokens.clear();
  791. // find all positions of pattern sample_start
  792. size_t sample_begin = data_str.find(sample_start, 0);
  793. while (sample_begin != std::string::npos) {
  794. out_samples_begin.push_back(sample_begin);
  795. const size_t search_start = sample_begin + sample_start.size();
  796. sample_begin = data_str.find(sample_start, search_start);
  797. }
  798. if (out_samples_begin.size() == 0) {
  799. printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n",
  800. __func__, sample_start.c_str());
  801. out_samples_begin.push_back(0);
  802. }
  803. out_samples_size.resize(out_samples_begin.size(), 0);
  804. std::vector<char> buf_sample;
  805. std::vector<llama_token> tok_sample;
  806. const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size());
  807. size_t found_too_big_sample = 0;
  808. size_t found_too_small_sample = 0;
  809. size_t found_empty_sample = 0;
  810. size_t found_min_sample_size = SIZE_MAX;
  811. size_t found_max_sample_size = 0;
  812. size_t max_token_text_size = 0;
  813. int n_vocab = llama_n_vocab(llama_get_model(lctx));
  814. for (llama_token token=0; token < n_vocab; ++token) {
  815. max_token_text_size = std::max(
  816. max_token_text_size,
  817. strlen(llama_token_get_text(llama_get_model(lctx), token)));
  818. }
  819. // upper bound of context byte length.
  820. // strings with this byte length should always tokenize to at least context_length tokens.
  821. size_t context_byte_len = max_token_text_size*context_length;
  822. for (unsigned i=0; i<out_samples_begin.size(); ++i) {
  823. // determine sample begin and end from pattern positions
  824. size_t sample_begin = out_samples_begin[i] + sample_begin_offset;
  825. size_t sample_end = overlapping_samples
  826. ? std::min(
  827. data_str.size(),
  828. sample_begin + context_byte_len)
  829. : (i+1 < out_samples_begin.size()
  830. ? out_samples_begin[i+1]
  831. : data_str.size());
  832. if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
  833. // sample end is in the middle of an utf8 character.
  834. // advance sample_end to the begin of the next utf8 character.
  835. sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
  836. }
  837. size_t sample_size = sample_end - sample_begin;
  838. if (sample_size == 0) {
  839. ++found_empty_sample;
  840. }
  841. if (sample_size > 0) {
  842. // llama_tokenize expects zero terminated string,
  843. // copy sample into buffer and zero terminate it.
  844. buf_sample.resize(sample_size);
  845. memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size);
  846. // printf("sample: '%s'\n", buf_sample.data());
  847. // tokenize the sample
  848. tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
  849. int n_tokens = llama_tokenize(llama_get_model(lctx),
  850. buf_sample.data(),
  851. (int) buf_sample.size(),
  852. tok_sample.data(),
  853. (int) tok_sample.size(),
  854. false, false);
  855. if (n_tokens < 0) {
  856. tok_sample.resize(-n_tokens);
  857. n_tokens = llama_tokenize(llama_get_model(lctx),
  858. buf_sample.data(),
  859. (int) buf_sample.size(),
  860. tok_sample.data(),
  861. (int) tok_sample.size(),
  862. false, false);
  863. GGML_ASSERT(n_tokens >= 0);
  864. }
  865. GGML_ASSERT(n_tokens <= (int) tok_sample.size());
  866. if ((size_t) n_tokens > context_length) {
  867. ++found_too_big_sample;
  868. } else if ((size_t) n_tokens < context_length) {
  869. ++found_too_small_sample;
  870. }
  871. found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens);
  872. found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens);
  873. // write out tokens, start and size of sample
  874. // overwrite the string start position with the token start position
  875. out_samples_begin[i] = out_tokens.size();
  876. out_samples_size[i] = (size_t) n_tokens;
  877. out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens);
  878. } else {
  879. out_samples_begin[i] = out_tokens.size();
  880. out_samples_size[i] = 0;
  881. }
  882. }
  883. if (found_too_big_sample > 0) {
  884. printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n",
  885. __func__, found_too_big_sample, found_max_sample_size, context_length);
  886. }
  887. if (found_too_small_sample > 0) {
  888. printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n",
  889. __func__, found_too_small_sample, found_min_sample_size, context_length);
  890. }
  891. if (found_empty_sample) {
  892. printf("%s: warning: found %zu empty samples.\n",
  893. __func__, found_empty_sample);
  894. }
  895. }
  896. printf("%s: total number of samples: %zu\n",
  897. __func__, out_samples_begin.size());
  898. GGML_ASSERT(out_samples_begin.size() == out_samples_size.size());
  899. return out_tokens.size();
  900. }
  901. std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) {
  902. std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
  903. return replace_str(filename, pattern_it, sit.c_str());
  904. }
  905. struct train_params_common get_default_train_params_common() {
  906. struct train_params_common params;
  907. params.fn_train_data = "shakespeare.txt";
  908. params.fn_checkpoint_in = "checkpoint.gguf";
  909. params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
  910. params.pattern_fn_it = "ITERATION";
  911. params.fn_latest = "LATEST";
  912. params.print_usage = false;
  913. params.save_every = 10;
  914. params.seed = -1;
  915. params.n_ctx = 128;
  916. params.n_threads = 6;
  917. params.n_batch = 8;
  918. params.n_gradient_accumulation = 1;
  919. params.n_epochs = -1;
  920. params.n_gpu_layers = 0;
  921. params.custom_n_ctx = false;
  922. params.use_flash = false;
  923. params.use_checkpointing = true;
  924. params.sample_start = "";
  925. params.include_sample_start = false;
  926. params.escape = false;
  927. params.overlapping_samples = false;
  928. params.fill_with_next_samples = false;
  929. params.separate_with_eos = false;
  930. params.separate_with_bos = true;
  931. params.sample_random_offsets = false;
  932. params.force_reshuffle = false;
  933. params.opt_past = 0;
  934. params.opt_delta = 1e-5f;
  935. params.opt_max_no_improvement = 0;
  936. params.warmup = 100;
  937. params.cos_decay_steps = 1000;
  938. params.cos_decay_restart = 1.1f;
  939. params.cos_decay_min = 0.1f;
  940. params.enable_restart = false;
  941. params.adam_n_iter = 256;
  942. params.adam_alpha = 1e-3f;
  943. params.adam_min_alpha = 0;
  944. params.adam_decay = 1e-1f;
  945. params.adam_decay_min_ndim = 2;
  946. params.adam_beta1 = 0.9f;
  947. params.adam_beta2 = 0.999f;
  948. params.adam_gclip = 1.0f;
  949. params.adam_eps_f = 0.0f;
  950. return params;
  951. }
  952. void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) {
  953. // fprintf(stderr, "usage: %s [options]\n", argv[0]);
  954. // fprintf(stderr, "\n");
  955. // fprintf(stderr, "options:\n");
  956. // fprintf(stderr, " -h, --help show this help message and exit\n");
  957. fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
  958. fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
  959. fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
  960. fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
  961. fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
  962. fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
  963. fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
  964. fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
  965. fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
  966. fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
  967. fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
  968. fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
  969. fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n");
  970. fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
  971. fprintf(stderr, " --overlapping-samples Samples may overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
  972. fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
  973. fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
  974. fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
  975. fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
  976. fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
  977. fprintf(stderr, " --sample-random-offsets Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : "");
  978. fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
  979. fprintf(stderr, " --no-flash Don't use flash attention \n");
  980. fprintf(stderr, " --use-flash Use flash attention (default)\n");
  981. fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
  982. fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
  983. fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
  984. fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
  985. fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
  986. fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
  987. fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
  988. fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
  989. fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
  990. fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
  991. fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
  992. fprintf(stderr, " --epochs N Maximum number epochs to process. (default %d)\n", params->n_epochs);
  993. fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
  994. fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
  995. fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
  996. fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
  997. fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
  998. fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
  999. fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
  1000. fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
  1001. fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
  1002. fprintf(stderr, " -ngl N, --n-gpu-layers N Number of model layers to offload to GPU (default %d)", params->n_gpu_layers);
  1003. fprintf(stderr, "\n");
  1004. }
  1005. bool consume_common_train_arg(
  1006. int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param
  1007. ) {
  1008. int& i = *idx;
  1009. std::string arg = argv[i];
  1010. const std::string arg_prefix = "--";
    if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
        std::replace(arg.begin(), arg.end(), '_', '-');
    }
    if (arg == "--train-data") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->fn_train_data = argv[i];
    } else if (arg == "--checkpoint-in") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->fn_checkpoint_in = argv[i];
    } else if (arg == "--checkpoint-out") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->fn_checkpoint_out = argv[i];
    } else if (arg == "--pattern-fn-it") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->pattern_fn_it = argv[i];
    } else if (arg == "--fn-latest") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->fn_latest = argv[i];
    } else if (arg == "--save-every") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->save_every = std::stoi(argv[i]);
    } else if (arg == "-s" || arg == "--seed") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->seed = std::stoi(argv[i]);
    } else if (arg == "-c" || arg == "--ctx") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->n_ctx = std::stoi(argv[i]);
        params->custom_n_ctx = true;
    } else if (arg == "-t" || arg == "--threads") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->n_threads = std::stoi(argv[i]);
    } else if (arg == "-b" || arg == "--batch") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->n_batch = std::stoi(argv[i]);
    } else if (arg == "--grad-acc") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
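        // clamp to at least 1: zero or negative accumulation counts would be invalid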
        params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
    } else if (arg == "--sample-start") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->sample_start = std::string(argv[i]);
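    // the following options are boolean flags and take no value argument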
    } else if (arg == "--escape") {
        params->escape = true;
    } else if (arg == "--include-sample-start") {
        params->include_sample_start = true;
    } else if (arg == "--overlapping-samples") {
        params->overlapping_samples = true;
    } else if (arg == "--fill-with-next-samples") {
        params->fill_with_next_samples = true;
    } else if (arg == "--separate-with-eos") {
        params->separate_with_eos = true;
    } else if (arg == "--separate-with-bos") {
        params->separate_with_bos = true;
    } else if (arg == "--no-separate-with-eos") {
        params->separate_with_eos = false;
    } else if (arg == "--no-separate-with-bos") {
        params->separate_with_bos = false;
    } else if (arg == "--sample-random-offsets") {
        params->sample_random_offsets = true;
    } else if (arg == "--force-reshuffle") {
        params->force_reshuffle = true;
    } else if (arg == "--no-flash") {
        params->use_flash = false;
    } else if (arg == "--use-flash") {
        params->use_flash = true;
    } else if (arg == "--no-checkpointing") {
        params->use_checkpointing = false;
    } else if (arg == "--use-checkpointing") {
        params->use_checkpointing = true;
    } else if (arg == "--warmup") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->warmup = std::stoi(argv[i]);
    } else if (arg == "--cos-decay-steps") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->cos_decay_steps = std::stoi(argv[i]);
    } else if (arg == "--cos-decay-restart") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->cos_decay_restart = std::stof(argv[i]);
    } else if (arg == "--cos-decay-min") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->cos_decay_min = std::stof(argv[i]);
    } else if (arg == "--enable-restart") {
        params->enable_restart = true;
    } else if (arg == "--disable-restart") {
        params->enable_restart = false;
    } else if (arg == "--opt-past") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->opt_past = std::stoi(argv[i]);
    } else if (arg == "--opt-delta") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->opt_delta = std::stof(argv[i]);
    } else if (arg == "--opt-max-no-improvement") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->opt_max_no_improvement = std::stoi(argv[i]);
    } else if (arg == "--adam-epsf") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_eps_f = std::stof(argv[i]);
    } else if (arg == "--epochs") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->n_epochs = std::stoi(argv[i]);
    } else if (arg == "--adam-iter") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_n_iter = std::stoi(argv[i]);
    } else if (arg == "--adam-alpha") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_alpha = std::stof(argv[i]);
    } else if (arg == "--adam-min-alpha") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_min_alpha = std::stof(argv[i]);
    } else if (arg == "--adam-decay") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_decay = std::stof(argv[i]);
    } else if (arg == "--adam-decay-min-ndim") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_decay_min_ndim = std::stoi(argv[i]);
    } else if (arg == "--adam-beta1") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_beta1 = std::stof(argv[i]);
    } else if (arg == "--adam-beta2") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_beta2 = std::stof(argv[i]);
    } else if (arg == "--adam-gclip") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_gclip = std::stof(argv[i]);
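    // the number of GPU layers is only honored when built with GPU offload support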
    } else if (arg == "-ngl" || arg == "--n-gpu-layers") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        if (llama_supports_gpu_offload()) {
            params->n_gpu_layers = std::stoi(argv[i]);
        } else {
            fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
        }
    } else if (arg == "-h" || arg == "--help") {
        params->print_usage = true;
        return true;
    } else {
        return false;
    }
    return true;
}
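
// Typical caller pattern (a minimal sketch, not part of this file): iterate over
// argv, let consume_common_train_arg handle the shared options first, and fall
// back to program-specific parsing when it returns false:
//
//     for (int i = 1; i < argc; ++i) {
//         bool invalid_param = false;
//         if (consume_common_train_arg(argc, argv, &i, &params.common, &invalid_param)) {
//             if (invalid_param) {
//                 fprintf(stderr, "error: invalid parameter for argument: %s\n", argv[i-1]);
//                 exit(1);
//             }
//         } else {
//             // handle arguments specific to this program here
//         }
//     }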

void finish_processing_train_args(struct train_params_common * params) {
    if (params->escape) {
        string_process_escapes(params->sample_start);
    }
}
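
// Optimizer callback, invoked by the ggml optimizer once per gradient-accumulation
// step. At accum_step == 0 it begins a new optimizer iteration: it updates timing
// statistics, periodically saves checkpoints, computes the learning-rate schedule
// multiplier into *sched and prints progress. Every call then loads the next batch
// of training samples, and *cancel is set once the configured number of epochs
// has been completed.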
void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
    struct train_opt_callback_data * data = (struct train_opt_callback_data *) vdata;
    struct train_params_common * params = data->params;
    struct train_state * train = data->train;
    struct ggml_opt_context * opt = train->opt;
    int n_batch = params->n_batch;
    int n_ctx = params->n_ctx;

    if (accum_step == 0) {
        // time measurement
        int64_t now = ggml_time_ms();
        if (now > data->last_time && opt->iter > data->first_iter) {
            double dt = (double) (now - data->last_time);
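            // blend the new measurement into an exponential moving average of
            // per-iteration time; gain 0.7 favors recent iterations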
            if (data->millis_per_iter == 0.0) {
                data->millis_per_iter = dt;
            } else {
                const double gain = 0.7;
                data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain;
            }
        }

        double remaining_millis = 0.0;
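        // estimate the remaining time for this batch from the smoothed per-iteration time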
        if (data->millis_per_iter > 0.0) {
            const int n_iter = params->adam_n_iter;
            const int done_iter = opt->iter - data->first_iter;
            const int remaining_iter = n_iter - done_iter;
            remaining_millis = remaining_iter * data->millis_per_iter;
        }

        // file saving
        const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
        if (save_now) {
            int new_iters = opt->iter - data->last_save_iter;
            train->train_its += new_iters;
            train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
            if (data->save_cb) {
                data->save_cb(data->save_data, train);
            }
            data->last_save_iter = opt->iter;
        }

        // exclude file saving from time measurement, by measuring last_time after saving
        data->last_time = ggml_time_ms();
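
        // compute the learning-rate schedule multiplier for this iteration:
        // linear warmup followed by cosine decay, optionally with restarts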
        *sched = learning_schedule(
            opt->iter,
            params->warmup,
            params->cos_decay_steps,
            params->adam_alpha,
            params->adam_min_alpha,
            params->cos_decay_min,
            params->cos_decay_restart,
            params->enable_restart);
        int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
        if (impr_plot > 0) impr_plot = 0;
        if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0;
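        // note: impr_plot is computed above but not used by the output code below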
        printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
            __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
            *sched, opt->loss_after);
        if (data->millis_per_iter > 0) {
            printf(" dt=");
            print_duration(data->millis_per_iter);
            printf(" eta=");
            print_duration(remaining_millis);
        }
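        // draw a simple ASCII bar whose length is proportional to the loss improvement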
        float improvement = opt->loss_before - opt->loss_after;
        const float plot_scale = 10.0f;
        int bar_len = (int)(1 + improvement*plot_scale + 0.5);
        printf(" |");
        for (int i=0; i<bar_len; ++i) {
            printf("-");
        }
        printf(">");
        printf("\n");
    }
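
    // fetch the next batch of tokenized samples into tokens_input / target_probs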
    int64_t used_samples = get_example_targets_batch(
        data->lctx,
        data->tokens_input,
        data->target_probs,
        train->shuffle_next_sample,
        data->shuffled_samples_offs,
        data->shuffled_samples_begin,
        data->shuffled_samples_size,
        data->samples_count,
        data->tokens_data,
        data->tokens_size,
        params->separate_with_eos,
        params->separate_with_bos,
        params->fill_with_next_samples,
        params->sample_random_offsets);

    train->train_samples += used_samples;
    train->shuffle_next_sample += used_samples;
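
    // once every sample of the current shuffling has been visited, an epoch is
    // complete and the samples are reshuffled for the next epoch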
    if (train->shuffle_next_sample >= train->shuffle_sample_count) {
        ++train->train_epochs;
        printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
        // note: we may have used some samples from the current shuffling more than once
        train->shuffle_rng_state_current = train->shuffle_rng_state_next;
        train->shuffle_rng_state_next = shuffle_samples(
            train->shuffle_rng_state_current,
            data->shuffled_samples_offs,
            data->shuffled_samples_begin,
            data->shuffled_samples_size,
            data->samples_begin,
            data->samples_size,
            data->samples_count);
        train->shuffle_next_sample = 0;
    }

    const bool last_epoch_reached = (params->n_epochs > 0 && (int64_t) train->train_epochs - data->first_epoch >= params->n_epochs);
    if (last_epoch_reached) {
        // allow optimization iteration at last epoch to be completed before canceling
        if (data->iter_at_last_epoch < 0) {
            data->iter_at_last_epoch = opt->iter;
        } else if (opt->iter > data->iter_at_last_epoch) {
            *cancel = true;
        }
    }
}