// train.cpp
#include "train.h"
#include "common.h"

#include <random>
#include <sstream>
#include <functional>

struct random_normal_distribution {
    std::mt19937 gen;
    std::normal_distribution<float> rd;
    float min;
    float max;
};

struct random_uniform_distribution {
    std::mt19937 gen;
    std::uniform_real_distribution<float> rd;
};
struct train_state * init_train_state() {
    struct train_state * state = new struct train_state;
    state->train_its     = 0;
    state->train_samples = 0;
    state->train_tokens  = 0;
    state->train_epochs  = 0;

    state->shuffle_samples_hash      = 0;
    state->shuffle_sample_count      = 0;
    state->shuffle_next_sample       = 0;
    state->shuffle_rng_state_current = "";
    state->shuffle_rng_state_next    = "";

    state->opt = new struct ggml_opt_context;
    state->opt->ctx = NULL;
    state->opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
    state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
    state->opt->loss_after = 0.0f;

    return state;
}

void free_train_state(struct train_state * state) {
    delete state->opt;
    delete state;
}
struct random_normal_distribution * init_random_normal_distribution(
    int seed, float mean, float std, float min, float max
) {
    // allocate with new, not malloc: the struct holds non-trivially
    // constructible C++ members (std::mt19937, std::normal_distribution),
    // whose lifetime is never started in malloc'd storage
    struct random_normal_distribution * rnd = new random_normal_distribution;
    rnd->gen = std::mt19937(seed);
    rnd->rd  = std::normal_distribution<float>{mean, std};
    rnd->min = min;
    rnd->max = max;
    return rnd;
}

struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max) {
    struct random_uniform_distribution * rnd = new random_uniform_distribution;
    rnd->gen = std::mt19937(seed);
    rnd->rd  = std::uniform_real_distribution<float>{min, max};
    return rnd;
}

void free_random_normal_distribution (struct random_normal_distribution * rnd) {
    delete rnd;
}

void free_random_uniform_distribution(struct random_uniform_distribution * rnd) {
    delete rnd;
}
struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
    float scale = 1.0f; // xavier
    switch (ggml_n_dims(tensor)) {
        case 1:
            scale /= sqrtf((float) tensor->ne[0]);
            for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
                *dst = scale * frand_normal(rnd);
            }
            break;
        case 2:
            scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                    float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
                    *dst = scale * frand_normal(rnd);
                }
            }
            break;
        case 3:
            scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                        float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
                        *dst = scale * frand_normal(rnd);
                    }
                }
            }
            break;
        case 4:
            scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
            for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                            float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
                            *dst = scale * frand_normal(rnd);
                        }
                    }
                }
            }
            break;
        default:
            die("Unsupported tensor->n_dims");
    };
    return tensor;
}
struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
    switch (ggml_n_dims(tensor)) {
        case 1:
            for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
                *dst = frand_uniform(rnd);
            }
            break;
        case 2:
            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                    float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
                    *dst = frand_uniform(rnd);
                }
            }
            break;
        case 3:
            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                        float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
                        *dst = frand_uniform(rnd);
                    }
                }
            }
            break;
        case 4:
            for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                            float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
                            *dst = frand_uniform(rnd);
                        }
                    }
                }
            }
            break;
        default:
            die("Unsupported tensor->n_dims");
    };
    return tensor;
}
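
// Illustrative usage sketch (editor's addition, not part of the original file):
// initializing a fresh f32 weight tensor with the helpers above. The context
// size, tensor shape and seed are arbitrary example values.
static void example_randomize_tensor_usage() {
    struct ggml_init_params ip = { 1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(ip);
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 16);

    // seed, mean, std, clamp min, clamp max
    struct random_normal_distribution * rnd = init_random_normal_distribution(1337, 0.0f, 1.0f, -2.0f, 2.0f);
    randomize_tensor_normal(w, rnd); // fills w with xavier-scaled, clamped normals

    free_random_normal_distribution(rnd);
    ggml_free(ctx);
}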
float frand() {
    return (float)rand()/((float)(RAND_MAX) + 1.0f);
}

float frand_normal(struct random_normal_distribution * rnd) {
    return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
}

float frand_uniform(struct random_uniform_distribution * rnd) {
    return rnd->rd(rnd->gen);
}

int clamp(const int v, const int min, const int max) {
    return ((v < min) ? (min) : (v > max) ? (max) : v);
}

float fclamp(const float v, const float min, const float max) {
    return ((v < min) ? (min) : (v > max) ? (max) : v);
}
void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == 1);
    GGML_ASSERT(tensor->ne[2] == 1);
    GGML_ASSERT(tensor->ne[3] == 1);
}

void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
    GGML_ASSERT(tensor->ne[2] == 1);
    GGML_ASSERT(tensor->ne[3] == 1);
}

void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
    GGML_ASSERT(tensor->ne[2] == ne2);
    GGML_ASSERT(tensor->ne[3] == 1);
}

void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
    GGML_ASSERT(tensor->ne[2] == ne2);
    GGML_ASSERT(tensor->ne[3] == ne3);
}
int64_t get_example_targets_batch(
    struct llama_context * lctx,
    struct ggml_tensor   * tokens_input,
    struct ggml_tensor   * target_probs,
    int64_t                example_id,
    const size_t         * samples_offs,
    const size_t         * samples_begin,
    const size_t         * samples_size,
    size_t                 samples_count,
    const llama_token    * train_data,
    size_t                 n_train_data,
    bool                   separate_with_eos,
    bool                   separate_with_bos,
    bool                   fill_with_next_samples,
    bool                   sample_random_offsets
) {
    GGML_ASSERT(samples_count > 0);
    GGML_ASSERT(ggml_is_matrix(tokens_input));
    GGML_ASSERT(ggml_is_3d(target_probs));
    int64_t n_vocab  = target_probs->ne[0];
    int64_t n_tokens = tokens_input->ne[0];
    int64_t n_batch  = tokens_input->ne[1];
    GGML_ASSERT(n_vocab  == target_probs->ne[0]);
    GGML_ASSERT(n_tokens == target_probs->ne[1]);
    GGML_ASSERT(n_batch  == target_probs->ne[2]);

    int64_t used_samples = 0;

    ggml_set_f32(target_probs, 0.0f);
    llama_token bos = llama_token_bos(llama_get_model(lctx));
    llama_token eos = llama_token_eos(llama_get_model(lctx));
    // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
    for (int k=0; k<n_batch; ++k) {
        // printf("%s: batch %d\n", __func__, k);
        size_t sample_idx   = (example_id + used_samples) % samples_count;
        size_t sample_offs  = sample_random_offsets ? samples_offs[sample_idx] : 0;
        size_t sample_begin = samples_begin[sample_idx];
        size_t sample_size  = samples_size[sample_idx];
        ++used_samples;

        // printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
        GGML_ASSERT(sample_begin+sample_size-1 < n_train_data);

        ggml_set_i32_nd(tokens_input, 0, k, 0, 0, bos);
        bool sample_separation_eos = !separate_with_eos;
        bool sample_separation_bos = !separate_with_bos;
        for (int64_t i=0; i<n_tokens; ++i) {
            llama_token token = eos;
            if (sample_offs >= sample_size && fill_with_next_samples) {
                if (!sample_separation_eos) {
                    // insert eos token to separate samples
                    sample_separation_eos = true;
                } else if (!sample_separation_bos) {
                    // insert bos token to separate samples
                    sample_separation_bos = true;
                    token = bos;
                } else {
                    // sample separation is done, continue with next sample
                    sample_separation_eos = !separate_with_eos;
                    sample_separation_bos = !separate_with_bos;
                    sample_offs  = 0;
                    sample_idx   = (example_id + used_samples) % samples_count;
                    sample_begin = samples_begin[sample_idx];
                    sample_size  = samples_size[sample_idx];
                    ++used_samples;
                }
            }
            // note: no else-if here
            if (sample_offs < sample_size) {
                token = clamp(train_data[sample_begin+sample_offs], 0, (llama_token) (n_vocab - 1));
                ++sample_offs;
            }
            ggml_set_f32_nd(target_probs, token, (int) i, (int) k, 0, +1.0f);
            if (i+1<n_tokens) {
                ggml_set_i32_nd(tokens_input, (int) (i + 1), (int) k, 0, 0, token);
            }
        }
    }
    return used_samples;
}
void mt19937_set_state(std::mt19937& rng, const std::string& rng_state) {
    std::stringstream s_rng_state;
    s_rng_state.imbue(std::locale::classic());
    s_rng_state.exceptions(std::stringstream::failbit);
    s_rng_state.str(rng_state);
    s_rng_state >> rng;
}

std::string mt19937_get_state(const std::mt19937& rng) {
    std::stringstream s_rng_state;
    s_rng_state.imbue(std::locale::classic());
    s_rng_state << rng;
    return s_rng_state.str();
}

std::string mt19937_seed_to_state(unsigned seed) {
    std::mt19937 rng(seed);
    return mt19937_get_state(rng);
}
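
// Illustrative sketch (editor's addition): the serialized state round-trips,
// so a restored engine continues the exact same random sequence. This is what
// makes checkpointed shuffling reproducible. Values are arbitrary.
static void example_mt19937_state_roundtrip() {
    std::mt19937 a(1234);
    const std::string state = mt19937_get_state(a);
    const std::mt19937::result_type expected = a(); // advance the original by one draw

    std::mt19937 b;
    mt19937_set_state(b, state);  // restore b to a's saved state
    GGML_ASSERT(b() == expected); // b reproduces the same next draw
}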
std::string shuffle_samples(
        const std::string & rng_state,
        size_t            * shuffled_offs,
        size_t            * shuffled_begins,
        size_t            * shuffled_sizes,
        const size_t      * begins,
        const size_t      * sizes,
        size_t              count) {
    if (count == 0) return rng_state;

    std::mt19937 rng;
    mt19937_set_state(rng, rng_state);

    // sort indices by random value for each index
    std::vector<size_t> idcs;
    {
        std::vector<unsigned> rnd;
        idcs.resize(count);
        rnd.resize(count);
        for (unsigned i=0; i<count; ++i) {
            idcs[i] = i;
            rnd[i]  = rng();
        }

        std::sort(idcs.begin(), idcs.end(), [&rnd](size_t a, size_t b){
            // stable sort for reproducibility
            return (rnd[a] == rnd[b]) ? (a < b) : (rnd[a] < rnd[b]);
        });
    }

    // create random offsets
    for (unsigned i=0; i<count; ++i) {
        shuffled_offs[i] = (size_t) ((sizes[idcs[i]] - 1) * ((double) rng() / (double) (rng.max()-1)));
    }

    // reorder begins and sizes by sorted indices
    for (unsigned i=0; i<count; ++i) {
        shuffled_begins[i] = begins[idcs[i]];
    }
    for (unsigned i=0; i<count; ++i) {
        shuffled_sizes[i] = sizes[idcs[i]];
    }

    return mt19937_get_state(rng);
}
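
// Illustrative sketch (editor's addition): calling shuffle_samples with
// parallel vectors. The sample boundaries below are made up; in the trainer
// they come from tokenize_file.
static void example_shuffle_samples_usage() {
    const std::vector<size_t> begins = { 0, 100, 250 };
    const std::vector<size_t> sizes  = { 100, 150, 80 };
    const size_t count = begins.size();

    std::vector<size_t> sh_offs(count), sh_begins(count), sh_sizes(count);
    std::string rng_state = mt19937_seed_to_state(42);
    // returns the advanced RNG state, to be stored for the next reshuffle
    rng_state = shuffle_samples(
        rng_state, sh_offs.data(), sh_begins.data(), sh_sizes.data(),
        begins.data(), sizes.data(), count);
}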
size_t hash_combine(size_t h1, size_t h2) {
    return h1 ^ (h2 << 1);
}

size_t compute_samples_hash(const char* fn, const size_t* samples_begin, const size_t* samples_size, size_t sample_count) {
    std::hash<std::string> h_string;
    std::hash<unsigned long long> h_ull;
    size_t h = h_string(std::string(fn));
    h = hash_combine(h, h_ull((unsigned long long) sample_count));
    for (size_t i=0; i<sample_count; ++i) {
        h = hash_combine(h, h_ull((unsigned long long) samples_begin[i]));
        h = hash_combine(h, h_ull((unsigned long long) samples_size[i]));
    }
    return h;
}
// replaces only the first occurrence of needle
std::string replace_str(const char * s, const char * needle, const char * replacement) {
    std::string str = s;
    size_t pos = str.find(needle);
    if (pos != std::string::npos) {
        str.replace(pos, strlen(needle), replacement);
    }
    return str;
}
void print_duration(double fmillis) {
    if (fmillis < 1000.0f) {
        printf("%.1fms", (float) fmillis);
        return;
    }
    const int64_t one_sec  = 1000;
    const int64_t one_min  = one_sec  * 60;
    const int64_t one_hour = one_min  * 60;
    const int64_t one_day  = one_hour * 24;

    int64_t millis  = (int64_t) fmillis;
    int64_t days    = millis/one_day;
    int64_t hours   = (millis - days*one_day)/one_hour;
    int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min;
    int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec;

    // to print int64_t either cast to (long long int) or use macro PRId64 from <inttypes.h>
    if (days > 0) {
        printf("%lldd ", (long long int) days);
    }
    printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds);
}
float cosine_decay(int64_t step, int64_t decay_steps, float minimum) {
    if (step > decay_steps) {
        step = decay_steps;
    }
    const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
    const float decay = (1 - minimum)*cosine_decay + minimum;
    return decay;
}

float cosine_decay_restart(int64_t step, int64_t decay_steps, float minimum, float restart_step_mult) {
    while (step > decay_steps) {
        step -= decay_steps;
        decay_steps = (int64_t) (restart_step_mult * decay_steps);
    }
    return cosine_decay(step, decay_steps, minimum);
}
float learning_schedule(
    int64_t step,
    int64_t warmup_steps,
    int64_t cos_decay_steps,
    float   learning_rate,
    float   overall_minimum,
    float   cos_decay_minimum,
    float   cos_decay_restart_step_mult,
    bool    enable_restart) {

    float result =
        (step < warmup_steps)
            ? (float) step / (float) warmup_steps
            : enable_restart
                ? cosine_decay_restart(
                    step - warmup_steps,
                    cos_decay_steps,
                    cos_decay_minimum,
                    cos_decay_restart_step_mult)
                : cosine_decay(
                    step,
                    cos_decay_steps,
                    cos_decay_minimum);

    float min = overall_minimum / learning_rate;
    result = min + result * (1.0f - min);
    return result;
}
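
// Worked example (editor's addition): with warmup_steps=100 and
// overall_minimum=0, step 50 is still inside the linear warmup, so the
// schedule returns 50/100 = 0.5; the caller multiplies this by the base
// learning rate, yielding half the configured alpha. Values are arbitrary.
static void example_learning_schedule() {
    float s = learning_schedule(
        /*step*/ 50, /*warmup_steps*/ 100, /*cos_decay_steps*/ 1000,
        /*learning_rate*/ 1e-3f, /*overall_minimum*/ 0.0f,
        /*cos_decay_minimum*/ 0.1f, /*cos_decay_restart_step_mult*/ 1.1f,
        /*enable_restart*/ false);
    GGML_ASSERT(fabsf(s - 0.5f) < 1e-6f);
}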
static bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
    GGML_ASSERT(a != NULL);
    GGML_ASSERT(b != NULL);
    GGML_ASSERT(a->type == b->type);
    GGML_ASSERT(ggml_are_same_shape(a, b));
    GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
    return true;
}

void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
    if (dst == NULL) {
        return;
    }
    struct ggml_tensor * t = ggml_get_tensor(ctx, name);
    GGML_ASSERT(are_same_layout(dst, t));
    memcpy(dst->data, t->data, ggml_nbytes(t));

    if (strlen(ggml_get_name(dst)) == 0) {
        ggml_set_name(dst, name);
    }
}
// gguf constants
static const char * LLM_KV_OPTIMIZER_TYPE                       = "optimizer.type";
static const char * LLM_KV_OPTIMIZER_TYPE_ADAM                  = "adam";
static const char * LLM_KV_OPTIMIZER_TYPE_LBFGS                 = "lbfgs";
static const char * LLM_KV_OPTIMIZER_FILE_VERSION               = "optimizer.file_version";
static const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT     = "optimizer.convergence_past_count";
static const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT            = "optimizer.parameter_count";
static const char * LLM_KV_OPTIMIZER_ITERATION_COUNT            = "optimizer.iteration_count";
static const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED           = "optimizer.just_initialized";
static const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS             = "optimizer.adam.best_loss";
static const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS         = "optimizer.adam.previous_loss";
static const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT  = "optimizer.adam.no_improvement_count";
static const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
static const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS            = "optimizer.lbfgs.best_loss";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP     = "optimizer.lbfgs.line_search_step";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J        = "optimizer.lbfgs.line_search_j";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K        = "optimizer.lbfgs.line_search_k";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END      = "optimizer.lbfgs.line_search_end";
static const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";

static const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS     = "optimizer.adam.first_moments";
static const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS    = "optimizer.adam.second_moments";
static const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES  = "optimizer.adam.past_loss_values";

static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS  = "optimizer.lbfgs.current_parameters";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS   = "optimizer.lbfgs.current_gradients";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS  = "optimizer.lbfgs.previous_gradients";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION    = "optimizer.lbfgs.search_direction";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES    = "optimizer.lbfgs.past_loss_values";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA        = "optimizer.lbfgs.memory_alpha";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS           = "optimizer.lbfgs.memory_ys";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S            = "optimizer.lbfgs.memory_s";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y            = "optimizer.lbfgs.memory_y";

static const char * LLM_KV_TRAINING_FILE_VERSION         = "training.file_version";
static const char * LLM_KV_TRAINING_ITERATION_COUNT      = "training.iteration_count";
static const char * LLM_KV_TRAINING_SAMPLE_COUNT         = "training.sample_count";
static const char * LLM_KV_TRAINING_TOKEN_COUNT          = "training.token_count";
static const char * LLM_KV_TRAINING_EPOCH_COUNT          = "training.epoch_count";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE    = "training.shuffle.rng_state";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE  = "training.shuffle.next_sample";
// wrapped in do { ... } while(0) so the macro expands to a single statement
// and composes safely with if/else at the call site
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
do { \
    const std::string skey(key); \
    const int kid = gguf_find_key(ctx, skey.c_str()); \
    if (kid >= 0) { \
        enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
        if (ktype != (type)) { \
            die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
        } \
        (dst) = func(ctx, kid); \
    } else if (req) { \
        die_fmt("key not found in model: %s", skey.c_str()); \
    } \
} while (0)
void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
    // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data cannot be read
    uint32_t file_version;
    GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
    GGML_ASSERT(file_version == 0);

    GGUF_GET_KEY(fctx, opt->params.past,      gguf_get_val_u32,  GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
    GGUF_GET_KEY(fctx, opt->iter,             gguf_get_val_u32,  GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
    GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL,   true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);

    uint64_t nx;
    GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
    opt->nx = (size_t) nx;

    // don't call ggml_opt_init until optimizer type and optimizer specific parameters are known
    std::string opt_type;
    GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
    if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
        opt->params.type = GGML_OPT_TYPE_ADAM;

        GGUF_GET_KEY(fctx, opt->adam.fx_best,          gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
        GGUF_GET_KEY(fctx, opt->adam.fx_prev,          gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
        GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32,  true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);

        ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);

        copy_tensor_by_name(opt->adam.m,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
        copy_tensor_by_name(opt->adam.v,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
        copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
    } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
        opt->params.type = GGML_OPT_TYPE_LBFGS;

        GGUF_GET_KEY(fctx, opt->params.lbfgs.m,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
        GGUF_GET_KEY(fctx, opt->lbfgs.fx_best,          gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
        GGUF_GET_KEY(fctx, opt->lbfgs.step,             gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
        GGUF_GET_KEY(fctx, opt->lbfgs.j,                gguf_get_val_i32, GGUF_TYPE_INT32,   true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
        GGUF_GET_KEY(fctx, opt->lbfgs.k,                gguf_get_val_i32, GGUF_TYPE_INT32,   true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
        GGUF_GET_KEY(fctx, opt->lbfgs.end,              gguf_get_val_i32, GGUF_TYPE_INT32,   true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
        GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32,  true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);

        ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);

        copy_tensor_by_name(opt->lbfgs.x,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
        copy_tensor_by_name(opt->lbfgs.xp,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
        copy_tensor_by_name(opt->lbfgs.g,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
        copy_tensor_by_name(opt->lbfgs.gp,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
        copy_tensor_by_name(opt->lbfgs.d,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
        copy_tensor_by_name(opt->lbfgs.pf,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
        copy_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
        copy_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
        copy_tensor_by_name(opt->lbfgs.lms,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
        copy_tensor_by_name(opt->lbfgs.lmy,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
    } else {
        die("unknown optimizer type\n");
    }
}
void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
    gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
    gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
    gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
    gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
    gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);

    switch (opt->params.type) {
        case GGML_OPT_TYPE_ADAM:
            {
                gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS,            opt->adam.fx_best);
                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS,        opt->adam.fx_prev);
                gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);

                ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
                ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
                if (opt->adam.pf) {
                    ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
                }

                gguf_add_tensor(fctx, opt->adam.m);
                gguf_add_tensor(fctx, opt->adam.v);
                if (opt->adam.pf) {
                    gguf_add_tensor(fctx, opt->adam.pf);
                }
            } break;
        case GGML_OPT_TYPE_LBFGS:
            {
                gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
                gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS,            opt->lbfgs.fx_best);
                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP,     opt->lbfgs.step);
                gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J,        opt->lbfgs.j);
                gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K,        opt->lbfgs.k);
                gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END,      opt->lbfgs.end);
                gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);

                ggml_set_name(opt->lbfgs.x,  LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
                ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
                ggml_set_name(opt->lbfgs.g,  LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
                ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
                ggml_set_name(opt->lbfgs.d,  LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
                if (opt->lbfgs.pf) {
                    ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
                }
                ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
                ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
                ggml_set_name(opt->lbfgs.lms,  LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
                ggml_set_name(opt->lbfgs.lmy,  LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);

                gguf_add_tensor(fctx, opt->lbfgs.x);
                gguf_add_tensor(fctx, opt->lbfgs.xp);
                gguf_add_tensor(fctx, opt->lbfgs.g);
                gguf_add_tensor(fctx, opt->lbfgs.gp);
                gguf_add_tensor(fctx, opt->lbfgs.d);
                if (opt->lbfgs.pf) {
                    gguf_add_tensor(fctx, opt->lbfgs.pf);
                }
                gguf_add_tensor(fctx, opt->lbfgs.lmal);
                gguf_add_tensor(fctx, opt->lbfgs.lmys);
                gguf_add_tensor(fctx, opt->lbfgs.lms);
                gguf_add_tensor(fctx, opt->lbfgs.lmy);
            } break;
    }
}
bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) {
    if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) < 0) {
        return false;
    }

    uint32_t file_version;
    GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
    GGML_ASSERT(file_version <= 1);

    if (file_version == 0) {
        GGUF_GET_KEY(fctx, train->train_its,     gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
        GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
        GGUF_GET_KEY(fctx, train->train_tokens,  gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
    } else if (file_version == 1) {
        GGUF_GET_KEY(fctx, train->train_its,     gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT);
        GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT);
        GGUF_GET_KEY(fctx, train->train_tokens,  gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT);
        GGUF_GET_KEY(fctx, train->train_epochs,  gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT);

        GGUF_GET_KEY(fctx, train->shuffle_samples_hash,      gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH);
        GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE);
        GGUF_GET_KEY(fctx, train->shuffle_sample_count,      gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT);
        GGUF_GET_KEY(fctx, train->shuffle_next_sample,       gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE);
    }

    load_opt_context_gguf(fctx, f_ggml_ctx, train->opt);
    return true;
}
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
    gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION,    1);
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT,    train->train_samples);
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT,     train->train_tokens);
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT,     train->train_epochs);

    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash);
    gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE,    train->shuffle_rng_state_current.c_str());
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count);
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE,  (uint64_t) train->shuffle_next_sample);

    save_opt_context_gguf(fctx, train->opt);
}
struct llama_file {
    // use FILE * so we don't have to re-open the file to mmap
    FILE * fp;
    size_t size;

    llama_file(const char * fname, const char * mode) {
        fp = std::fopen(fname, mode);
        if (fp == NULL) {
            size = 0;
        } else {
            seek(0, SEEK_END);
            size = tell();
            seek(0, SEEK_SET);
        }
    }

    size_t tell() const {
#ifdef _WIN32
        __int64 ret = _ftelli64(fp);
#else
        long ret = std::ftell(fp);
#endif
        GGML_ASSERT(ret != -1); // this really shouldn't fail
        return (size_t) ret;
    }

    void seek(size_t offset, int whence) {
#ifdef _WIN32
        int ret = _fseeki64(fp, (__int64) offset, whence);
#else
        int ret = std::fseek(fp, (long) offset, whence);
#endif
        GGML_ASSERT(ret == 0); // same
    }

    void read_raw(void * ptr, size_t size) {
        if (size == 0) {
            return;
        }
        errno = 0;
        std::size_t ret = std::fread(ptr, size, 1, fp);
        if (ferror(fp)) {
            die_fmt("read error: %s", strerror(errno));
        }
        if (ret != 1) {
            die("unexpectedly reached end of file");
        }
    }

    std::uint32_t read_u32() {
        std::uint32_t ret;
        read_raw(&ret, sizeof(ret));
        return ret;
    }

    std::string read_string(std::uint32_t len) {
        std::vector<char> chars(len);
        read_raw(chars.data(), len);
        return std::string(chars.data(), len);
    }

    void write_raw(const void * ptr, size_t size) {
        if (size == 0) {
            return;
        }
        errno = 0;
        size_t ret = std::fwrite(ptr, size, 1, fp);
        if (ret != 1) {
            die_fmt("write error: %s", strerror(errno));
        }
    }

    void write_u32(std::uint32_t val) {
        write_raw(&val, sizeof(val));
    }

    ~llama_file() {
        if (fp) {
            std::fclose(fp);
        }
    }
};
static size_t utf8_len(char src) {
    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
    return lookup[highbits];
}

// mark each byte with its utf8 unit number.
// returns the number of utf8 characters.
// e.g. when bytes == '\x61\xD0\xB0\x62',
// then utf8_units  will become [0,0,1,0],
// utf8_nunits will become [1,2,2,1] and 3 is returned.
// bytes where utf8_units is zero are the beginning of a utf8 character.
static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) {
    size_t offs = 0;
    size_t count_utf8 = 0;
    while(offs < count) {
        int len = (int) utf8_len(bytes[offs]);
        for (int i=0; i<len; ++i) {
            utf8_units[offs+i]  = i;
            utf8_nunits[offs+i] = len;
        }
        offs += len;
        ++count_utf8;
    }
    return count_utf8;
}
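
// Illustrative sketch (editor's addition), mirroring the example in the
// comment above: 'a', the two-byte sequence for Cyrillic 'а', then 'b'.
static void example_mark_utf8_units() {
    const char bytes[] = "\x61\xD0\xB0\x62";
    int units[4];
    int nunits[4];
    size_t n = mark_utf8_units(bytes, units, nunits, 4);
    GGML_ASSERT(n == 3);         // three utf8 characters in four bytes
    GGML_ASSERT(units[2] == 1);  // third byte is the 2nd unit of a character
    GGML_ASSERT(nunits[1] == 2); // that character is two bytes long
}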
size_t tokenize_file(
        struct llama_context     * lctx,
        const char               * filename,
        const std::string        & sample_start,
        bool                       include_sample_start,
        bool                       overlapping_samples,
        unsigned                   context_length,
        std::vector<llama_token> & out_tokens,
        std::vector<size_t>      & out_samples_begin,
        std::vector<size_t>      & out_samples_size) {
    struct llama_file f(filename, "rb");

    if (f.size == 0) {
        out_tokens.clear();
        out_samples_begin.clear();
        out_samples_size.clear();
        printf("%s: warning: empty or non-existent training data file '%s'\n",
            __func__, filename);
        return out_tokens.size();
    }

    // account for possible leading whitespace that will be added by tokenizer
    // e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
    const int n_max_tokens_overhead = 1;

    std::vector<char> buf;
    buf.resize(f.size);

    f.read_raw(buf.data(), f.size);

    std::vector<int> utf8_units;
    std::vector<int> utf8_nunits;
    utf8_units.resize(buf.size());
    utf8_nunits.resize(buf.size());
    mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
    if (sample_start.size() == 0) {
        // tokenize all data at once
        out_tokens.resize(buf.size() + n_max_tokens_overhead);

        int n_tokens = llama_tokenize(
            llama_get_model(lctx),
            buf.data(),
            (int) buf.size(),
            out_tokens.data(),
            (int) out_tokens.size(),
            false, false);
        if (n_tokens < 0) {
            out_tokens.resize(-n_tokens);
            n_tokens = llama_tokenize(
                llama_get_model(lctx),
                buf.data(),
                (int) buf.size(),
                out_tokens.data(),
                (int) out_tokens.size(),
                false, false);
        }
        if (n_tokens >= 0) {
            out_tokens.resize(n_tokens);
        }

        // generate sample starts at all token positions
        out_samples_begin.clear();
        out_samples_begin.push_back(0);
        out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size()));
        size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0;
        for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) {
            out_samples_begin.push_back(sample_begin);
            out_samples_size.push_back(context_length);
        }
    } else {
        // split data into samples and tokenize each sample
        std::string data_str(buf.data(), buf.size());
        out_samples_begin.clear();
        out_samples_size.clear();
        out_tokens.clear();

        // find all positions of pattern sample_start
        size_t sample_begin = data_str.find(sample_start, 0);
        while (sample_begin != std::string::npos) {
            out_samples_begin.push_back(sample_begin);
            const size_t search_start = sample_begin + sample_start.size();
            sample_begin = data_str.find(sample_start, search_start);
        }
        if (out_samples_begin.size() == 0) {
            printf("%s: warning: sample start pattern '%s' not found. inserting a single sample at the beginning of the data\n",
                __func__, sample_start.c_str());
            out_samples_begin.push_back(0);
        }

        out_samples_size.resize(out_samples_begin.size(), 0);

        std::vector<char>        buf_sample;
        std::vector<llama_token> tok_sample;

        const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size());
        size_t found_too_big_sample   = 0;
        size_t found_too_small_sample = 0;
        size_t found_empty_sample     = 0;
        size_t found_min_sample_size  = SIZE_MAX;
        size_t found_max_sample_size  = 0;

        size_t max_token_text_size = 0;
        int n_vocab = llama_n_vocab(llama_get_model(lctx));
        for (llama_token token=0; token < n_vocab; ++token) {
            max_token_text_size = std::max(
                max_token_text_size,
                strlen(llama_token_get_text(llama_get_model(lctx), token)));
        }

        // upper bound of context byte length.
        // strings with this byte length should always tokenize to at least context_length tokens.
        size_t context_byte_len = max_token_text_size*context_length;
        for (unsigned i=0; i<out_samples_begin.size(); ++i) {
            // determine sample begin and end from pattern positions
            size_t sample_begin = out_samples_begin[i] + sample_begin_offset;
            size_t sample_end   = overlapping_samples
                ? std::min(
                    data_str.size(),
                    sample_begin + context_byte_len)
                : (i+1 < out_samples_begin.size()
                    ? out_samples_begin[i+1]
                    : data_str.size());
            if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
                // sample end is in the middle of an utf8 character.
                // advance sample_end to the begin of the next utf8 character.
                sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
            }
            size_t sample_size = sample_end - sample_begin;
            if (sample_size == 0) {
                ++found_empty_sample;
            }

            if (sample_size > 0) {
                // llama_tokenize expects zero terminated string,
                // copy sample into buffer and zero terminate it.
                buf_sample.resize(sample_size);
                memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size);

                // printf("sample: '%s'\n", buf_sample.data());

                // tokenize the sample
                tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
                int n_tokens = llama_tokenize(llama_get_model(lctx),
                    buf_sample.data(),
                    (int) buf_sample.size(),
                    tok_sample.data(),
                    (int) tok_sample.size(),
                    false, false);
                if (n_tokens < 0) {
                    tok_sample.resize(-n_tokens);
                    n_tokens = llama_tokenize(llama_get_model(lctx),
                        buf_sample.data(),
                        (int) buf_sample.size(),
                        tok_sample.data(),
                        (int) tok_sample.size(),
                        false, false);
                    GGML_ASSERT(n_tokens >= 0);
                }
                GGML_ASSERT(n_tokens <= (int) tok_sample.size());

                if ((size_t) n_tokens > context_length) {
                    ++found_too_big_sample;
                } else if ((size_t) n_tokens < context_length) {
                    ++found_too_small_sample;
                }
                found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens);
                found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens);

                // write out tokens, start and size of sample
                // overwrite the string start position with the token start position
                out_samples_begin[i] = out_tokens.size();
                out_samples_size[i]  = (size_t) n_tokens;
                out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens);
            } else {
                out_samples_begin[i] = out_tokens.size();
                out_samples_size[i]  = 0;
            }
        }
        if (found_too_big_sample > 0) {
            printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n",
                __func__, found_too_big_sample, found_max_sample_size, context_length);
        }

        if (found_too_small_sample > 0) {
            printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n",
                __func__, found_too_small_sample, found_min_sample_size, context_length);
        }

        if (found_empty_sample) {
            printf("%s: warning: found %zu empty samples.\n",
                __func__, found_empty_sample);
        }
    }
    printf("%s: total number of samples: %zu\n",
        __func__, out_samples_begin.size());

    GGML_ASSERT(out_samples_begin.size() == out_samples_size.size());

    return out_tokens.size();
}
std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) {
    std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
    return replace_str(filename, pattern_it, sit.c_str());
}
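
// Illustrative sketch (editor's addition): how the iteration pattern in an
// output filename is resolved. The names mirror the defaults below.
static void example_get_train_filename() {
    // with a concrete iteration number, the pattern is replaced by it:
    GGML_ASSERT(get_train_filename("checkpoint-ITERATION.gguf", "ITERATION", "LATEST", 42)
        == "checkpoint-42.gguf");
    // with a negative iteration, the 'latest' placeholder is used instead:
    GGML_ASSERT(get_train_filename("checkpoint-ITERATION.gguf", "ITERATION", "LATEST", -1)
        == "checkpoint-LATEST.gguf");
}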
struct train_params_common get_default_train_params_common() {
    struct train_params_common params;
    params.fn_train_data     = "shakespeare.txt";
    params.fn_checkpoint_in  = "checkpoint.gguf";
    params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
    params.pattern_fn_it     = "ITERATION";
    params.fn_latest         = "LATEST";

    params.print_usage = false;

    params.save_every = 10;

    params.seed = -1;

    params.n_ctx                   = 128;
    params.n_threads               = 6;
    params.n_batch                 = 8;
    params.n_gradient_accumulation = 1;
    params.n_epochs                = -1;
    params.n_gpu_layers            = 0;

    params.custom_n_ctx = false;

    params.use_flash         = true;
    params.use_checkpointing = true;

    params.sample_start           = "";
    params.include_sample_start   = false;
    params.escape                 = false;
    params.overlapping_samples    = false;
    params.fill_with_next_samples = false;
    params.separate_with_eos      = false;
    params.separate_with_bos      = true;
    params.sample_random_offsets  = false;
    params.force_reshuffle        = false;

    params.opt_past               = 0;
    params.opt_delta              = 1e-5f;
    params.opt_max_no_improvement = 0;

    params.warmup            =  100;
    params.cos_decay_steps   = 1000;
    params.cos_decay_restart = 1.1f;
    params.cos_decay_min     = 0.1f;
    params.enable_restart    = false;

    params.adam_n_iter         = 256;
    params.adam_alpha          = 1e-3f;
    params.adam_min_alpha      = 0;
    params.adam_decay          = 1e-1f;
    params.adam_decay_min_ndim = 2;
    params.adam_beta1          = 0.9f;
    params.adam_beta2          = 0.999f;
    params.adam_gclip          = 1.0f;
    params.adam_eps_f          = 0.0f;

    return params;
}
void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) {
    // fprintf(stderr, "usage: %s [options]\n", argv[0]);
    // fprintf(stderr, "\n");
    // fprintf(stderr, "options:\n");
    // fprintf(stderr, "  -h, --help                 show this help message and exit\n");
    fprintf(stderr, "  --train-data FNAME         path from which to load training data (default '%s')\n", params->fn_train_data);
    fprintf(stderr, "  --checkpoint-in FNAME      path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
    fprintf(stderr, "  --checkpoint-out FNAME     path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
    fprintf(stderr, "  --pattern-fn-it STR        pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
    fprintf(stderr, "  --fn-latest STR            string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
    fprintf(stderr, "  --save-every N             save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
    fprintf(stderr, "  -s SEED, --seed SEED       RNG seed (default: -1, use random seed for -1)\n");
    fprintf(stderr, "  -c N, --ctx N              Context size used during training (default %d)\n", params->n_ctx);
    fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
    fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
    fprintf(stderr, "  --grad-acc N               Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
    fprintf(stderr, "  --sample-start STR         Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
    fprintf(stderr, "  --include-sample-start     Include the sample start in the samples. (default off)\n");
    fprintf(stderr, "  --escape                   process sample start escape sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
    fprintf(stderr, "  --overlapping-samples      Samples may overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
    fprintf(stderr, "  --fill-with-next-samples   Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
    fprintf(stderr, "  --separate-with-eos        When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
    fprintf(stderr, "  --separate-with-bos        When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
    fprintf(stderr, "  --no-separate-with-eos     When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
    fprintf(stderr, "  --no-separate-with-bos     When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
    fprintf(stderr, "  --sample-random-offsets    Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : "");
    fprintf(stderr, "  --force-reshuffle          Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
    fprintf(stderr, "  --no-flash                 Don't use flash attention\n");
    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
    fprintf(stderr, "  --no-checkpointing         Don't use gradient checkpointing\n");
    fprintf(stderr, "  --use-checkpointing        Use gradient checkpointing (default)\n");
    fprintf(stderr, "  --warmup N                 Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
    fprintf(stderr, "  --cos-decay-steps N        Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
    fprintf(stderr, "  --cos-decay-restart N      Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
    fprintf(stderr, "  --cos-decay-min N          Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
    fprintf(stderr, "  --enable-restart N         Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
    fprintf(stderr, "  --disable-restart N        Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
    fprintf(stderr, "  --opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
    fprintf(stderr, "  --opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
    fprintf(stderr, "  --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
    fprintf(stderr, "  --epochs N                 Maximum number of epochs to process. (default %d)\n", params->n_epochs);
    fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
    fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
    fprintf(stderr, "  --adam-min-alpha N         Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
    fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater than zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
    fprintf(stderr, "  --adam-decay-min-ndim N    Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with fewer dimensions. (default %d)\n", params->adam_decay_min_ndim);
    fprintf(stderr, "  --adam-beta1 N             AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
    fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
    fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
    fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
    fprintf(stderr, "  -ngl N, --n-gpu-layers N   Number of model layers to offload to GPU (default %d)", params->n_gpu_layers);
    fprintf(stderr, "\n");
}
  1003. bool consume_common_train_arg(
  1004. int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param
  1005. ) {
  1006. int& i = *idx;
  1007. std::string arg = argv[i];
  1008. const std::string arg_prefix = "--";
  1009. if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
  1010. std::replace(arg.begin(), arg.end(), '_', '-');
  1011. }
    if (arg == "--train-data") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->fn_train_data = argv[i];
    } else if (arg == "--checkpoint-in") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->fn_checkpoint_in = argv[i];
    } else if (arg == "--checkpoint-out") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->fn_checkpoint_out = argv[i];
    } else if (arg == "--pattern-fn-it") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->pattern_fn_it = argv[i];
    } else if (arg == "--fn-latest") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->fn_latest = argv[i];
    } else if (arg == "--save-every") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->save_every = std::stoi(argv[i]);
    } else if (arg == "-s" || arg == "--seed") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->seed = std::stoi(argv[i]);
    } else if (arg == "-c" || arg == "--ctx") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->n_ctx = std::stoi(argv[i]);
        params->custom_n_ctx = true;
    } else if (arg == "-t" || arg == "--threads") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->n_threads = std::stoi(argv[i]);
    } else if (arg == "-b" || arg == "--batch") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->n_batch = std::stoi(argv[i]);
    } else if (arg == "--grad-acc") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
    } else if (arg == "--sample-start") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->sample_start = std::string(argv[i]);
    } else if (arg == "--escape") {
        params->escape = true;
    } else if (arg == "--include-sample-start") {
        params->include_sample_start = true;
    } else if (arg == "--overlapping-samples") {
        params->overlapping_samples = true;
    } else if (arg == "--fill-with-next-samples") {
        params->fill_with_next_samples = true;
    } else if (arg == "--separate-with-eos") {
        params->separate_with_eos = true;
    } else if (arg == "--separate-with-bos") {
        params->separate_with_bos = true;
    } else if (arg == "--no-separate-with-eos") {
        params->separate_with_eos = false;
    } else if (arg == "--no-separate-with-bos") {
        params->separate_with_bos = false;
    } else if (arg == "--sample-random-offsets") {
        params->sample_random_offsets = true;
    } else if (arg == "--force-reshuffle") {
        params->force_reshuffle = true;
    } else if (arg == "--no-flash") {
        params->use_flash = false;
    } else if (arg == "--use-flash") {
        params->use_flash = true;
    } else if (arg == "--no-checkpointing") {
        params->use_checkpointing = false;
    } else if (arg == "--use-checkpointing") {
        params->use_checkpointing = true;
    } else if (arg == "--warmup") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->warmup = std::stoi(argv[i]);
    } else if (arg == "--cos-decay-steps") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->cos_decay_steps = std::stoi(argv[i]);
    } else if (arg == "--cos-decay-restart") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->cos_decay_restart = std::stof(argv[i]);
    } else if (arg == "--cos-decay-min") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->cos_decay_min = std::stof(argv[i]);
    } else if (arg == "--enable-restart") {
        params->enable_restart = true;
    } else if (arg == "--disable-restart") {
        params->enable_restart = false;
    } else if (arg == "--opt-past") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->opt_past = std::stoi(argv[i]);
    } else if (arg == "--opt-delta") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->opt_delta = std::stof(argv[i]);
    } else if (arg == "--opt-max-no-improvement") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->opt_max_no_improvement = std::stoi(argv[i]);
    } else if (arg == "--adam-epsf") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_eps_f = std::stof(argv[i]);
    } else if (arg == "--epochs") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->n_epochs = std::stoi(argv[i]);
    } else if (arg == "--adam-iter") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_n_iter = std::stoi(argv[i]);
    } else if (arg == "--adam-alpha") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_alpha = std::stof(argv[i]);
    } else if (arg == "--adam-min-alpha") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_min_alpha = std::stof(argv[i]);
    } else if (arg == "--adam-decay") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_decay = std::stof(argv[i]);
    } else if (arg == "--adam-decay-min-ndim") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_decay_min_ndim = std::stoi(argv[i]);
    } else if (arg == "--adam-beta1") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_beta1 = std::stof(argv[i]);
    } else if (arg == "--adam-beta2") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_beta2 = std::stof(argv[i]);
    } else if (arg == "--adam-gclip") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        params->adam_gclip = std::stof(argv[i]);
    } else if (arg == "-ngl" || arg == "--n-gpu-layers") {
        if (++i >= argc) {
            *invalid_param = true;
            return true;
        }
        if (llama_supports_gpu_offload()) {
            params->n_gpu_layers = std::stoi(argv[i]);
        } else {
            fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
        }
    } else if (arg == "-h" || arg == "--help") {
        params->print_usage = true;
        return true;
    } else {
        return false;
    }
    return true;
}
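
// post-processing after all arguments have been consumed:
// if --escape was given, process escape sequences (e.g. "\n") in --sample-start in place.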
void finish_processing_train_args(struct train_params_common * params) {
    if (params->escape) {
        process_escapes(params->sample_start);
    }
}
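
// optimizer callback, invoked once per gradient accumulation step.
// on the first accumulation step of each iteration it updates timing
// statistics, periodically saves checkpoints, computes the learning
// schedule and prints a progress line; on every step it loads the next
// batch of example targets and handles epoch bookkeeping and cancellation.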
void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
    struct train_opt_callback_data * data = (struct train_opt_callback_data *) vdata;
    struct train_params_common * params = data->params;
    struct train_state * train = data->train;
    struct ggml_opt_context * opt = train->opt;
    int n_batch = params->n_batch;
    int n_ctx = params->n_ctx;
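
    // per-iteration bookkeeping (timing, checkpoint saving, learning
    // schedule, progress output) happens on the first accumulation step only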
    if (accum_step == 0) {
        // time measurement
        int64_t now = ggml_time_ms();
        if (now > data->last_time && opt->iter > data->first_iter) {
            double dt = (double) (now - data->last_time);
            if (data->millis_per_iter == 0.0) {
                data->millis_per_iter = dt;
            } else {
                // exponentially smoothed average of the per-iteration time
                const double gain = 0.7;
                data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain;
            }
        }
        // estimate the remaining time from the smoothed per-iteration time
        double remaining_millis = 0.0;
        if (data->millis_per_iter > 0.0) {
            const int n_iter = params->adam_n_iter;
            const int done_iter = opt->iter - data->first_iter;
            const int remaining_iter = n_iter - done_iter;
            remaining_millis = remaining_iter * data->millis_per_iter;
        }
        // file saving
        const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
        if (save_now) {
            int new_iters = opt->iter - data->last_save_iter;
            train->train_its    += new_iters;
            train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
            if (data->save_cb) {
                data->save_cb(data->save_data, train);
            }
            data->last_save_iter = opt->iter;
        }
        // exclude file saving from time measurement, by measuring last_time after saving
        data->last_time = ggml_time_ms();
        *sched = learning_schedule(
            opt->iter,
            params->warmup,
            params->cos_decay_steps,
            params->adam_alpha,
            params->adam_min_alpha,
            params->cos_decay_min,
            params->cos_decay_restart,
            params->enable_restart);

        // impr_plot is clamped to zero for regressions and NaN losses;
        // it is computed here but currently unused in the printed output
        int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
        if (impr_plot > 0) impr_plot = 0;
        if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0;

        printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
            __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
            *sched, opt->loss_after);
        if (data->millis_per_iter > 0) {
            printf(" dt=");
            print_duration(data->millis_per_iter);
            printf(" eta=");
            print_duration(remaining_millis);
        }

        // ASCII bar visualizing the loss improvement of this iteration
        float improvement = opt->loss_before - opt->loss_after;
        const float plot_scale = 10.0f;
        int bar_len = (int)(1 + improvement*plot_scale + 0.5f);
        printf(" |");
        for (int i=0; i<bar_len; ++i) {
            printf("-");
        }
        printf(">");
        printf("\n");
    }
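
    // copy the next batch of shuffled training samples into the
    // input and target tensors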
    int64_t used_samples = get_example_targets_batch(
        data->lctx,
        data->tokens_input,
        data->target_probs,
        train->shuffle_next_sample,
        data->shuffled_samples_offs,
        data->shuffled_samples_begin,
        data->shuffled_samples_size,
        data->samples_count,
        data->tokens_data,
        data->tokens_size,
        params->separate_with_eos,
        params->separate_with_bos,
        params->fill_with_next_samples,
        params->sample_random_offsets);
    train->train_samples += used_samples;
    train->shuffle_next_sample += used_samples;
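
    // all samples of the current shuffling consumed -> epoch complete;
    // reshuffle the sample order for the next pass over the data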
    if (train->shuffle_next_sample >= train->shuffle_sample_count) {
        ++train->train_epochs;
        printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
        // note: we may have used some samples from the current shuffling more than once
        train->shuffle_rng_state_current = train->shuffle_rng_state_next;
        train->shuffle_rng_state_next = shuffle_samples(
            train->shuffle_rng_state_current,
            data->shuffled_samples_offs,
            data->shuffled_samples_begin,
            data->shuffled_samples_size,
            data->samples_begin,
            data->samples_size,
            data->samples_count);
        train->shuffle_next_sample = 0;
    }
    const bool last_epoch_reached = (params->n_epochs > 0 && (int64_t) train->train_epochs - data->first_epoch >= params->n_epochs);
    if (last_epoch_reached) {
        // allow optimization iteration at last epoch to be completed before canceling
        if (data->iter_at_last_epoch < 0) {
            data->iter_at_last_epoch = opt->iter;
        } else if (opt->iter > data->iter_at_last_epoch) {
            *cancel = true;
        }
    }
}
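
// Illustrative sketch (not part of this file): the training examples are
// expected to fill a train_opt_callback_data and hand this callback to
// ggml's optimizer roughly like so -- exact call sites vary per example:
//
//     struct train_opt_callback_data opt_cb_data;
//     opt_cb_data.params = &params.common;
//     opt_cb_data.train  = train;
//     // ... remaining fields (lctx, tokens_input, target_probs, ...) ...
//     ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
//
// where ctx0, loss, gf and gb are the example's ggml context, loss tensor
// and forward/backward graphs.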