train.cpp 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496
#include "train.h"
#include "common.h"

#include <algorithm>
#include <functional>
#include <new>
#include <random>
#include <sstream>
#include <string>
#include <vector>
// Normal (Gaussian) RNG whose samples are clipped to [min, max] (see frand_normal).
struct random_normal_distribution {
    std::mt19937 gen;                   // Mersenne Twister generator state
    std::normal_distribution<float> rd; // distribution parameters (mean, stddev)
    float min; // lower clip bound applied to each sample
    float max; // upper clip bound applied to each sample
};
// Uniform RNG over [min, max) (bounds are stored inside the distribution itself).
struct random_uniform_distribution {
    std::mt19937 gen;                        // Mersenne Twister generator state
    std::uniform_real_distribution<float> rd; // distribution over the configured range
};
// Allocate a new train_state with all counters zeroed and an embedded
// ggml_opt_context configured with default ADAM parameters.
// Caller owns the returned pointer; release it with free_train_state().
struct train_state * init_train_state() {
    struct train_state * state = new struct train_state;
    state->train_its     = 0;
    state->train_samples = 0;
    state->train_tokens  = 0;
    state->train_epochs  = 0;
    // shuffle bookkeeping starts empty; populated once training data is loaded
    state->shuffle_samples_hash = 0;
    state->shuffle_sample_count = 0;
    state->shuffle_next_sample  = 0;
    state->shuffle_rng_state_current = "";
    state->shuffle_rng_state_next    = "";
    state->opt = new struct ggml_opt_context;
    state->opt->ctx = NULL; // ggml context is attached later by the caller
    state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
    state->opt->loss_after = 0.0f;
    return state;
}
// Release a train_state created by init_train_state().
// NOTE(review): state->opt->ctx is not freed here — presumably owned by the
// caller that attached it; verify against call sites.
void free_train_state(struct train_state * state) {
    delete state->opt;
    delete state;
}
  37. struct random_normal_distribution * init_random_normal_distribution(
  38. int seed, float mean, float std, float min, float max
  39. ) {
  40. struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution));
  41. rnd->gen = std::mt19937(seed);
  42. rnd->rd = std::normal_distribution<float>{mean, std};
  43. rnd->min = min;
  44. rnd->max = max;
  45. return rnd;
  46. }
  47. struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max) {
  48. struct random_uniform_distribution * rnd = (struct random_uniform_distribution *) malloc(sizeof(struct random_uniform_distribution));
  49. rnd->gen = std::mt19937(seed);
  50. rnd->rd = std::uniform_real_distribution<float>{min, max};
  51. return rnd;
  52. }
// Free a distribution created by init_random_normal_distribution().
// Plain free() matches the malloc() in the init function; the std:: members
// hold no resources, so skipping the destructor is harmless in practice.
void free_random_normal_distribution (struct random_normal_distribution * rnd) {
    free(rnd);
}
// Free a distribution created by init_random_uniform_distribution().
// Plain free() matches the malloc() in the init function; the std:: members
// hold no resources, so skipping the destructor is harmless in practice.
void free_random_uniform_distribution(struct random_uniform_distribution * rnd) {
    free(rnd);
}
  59. struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
  60. float scale = 1.0f; // xavier
  61. switch (tensor->n_dims) {
  62. case 1:
  63. scale /= sqrtf((float) tensor->ne[0]);
  64. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  65. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
  66. *dst = scale * frand_normal(rnd);
  67. }
  68. break;
  69. case 2:
  70. scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
  71. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  72. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  73. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
  74. *dst = scale * frand_normal(rnd);
  75. }
  76. }
  77. break;
  78. case 3:
  79. scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
  80. for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
  81. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  82. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  83. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
  84. *dst = scale * frand_normal(rnd);
  85. }
  86. }
  87. }
  88. break;
  89. case 4:
  90. scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
  91. for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
  92. for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
  93. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  94. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  95. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
  96. *dst = scale * frand_normal(rnd);
  97. }
  98. }
  99. }
  100. }
  101. break;
  102. default:
  103. die("Unsupported tensor->n_dims");
  104. };
  105. return tensor;
  106. }
  107. struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
  108. switch (tensor->n_dims) {
  109. case 1:
  110. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  111. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
  112. *dst = frand_uniform(rnd);
  113. }
  114. break;
  115. case 2:
  116. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  117. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  118. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
  119. *dst = frand_uniform(rnd);
  120. }
  121. }
  122. break;
  123. case 3:
  124. for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
  125. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  126. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  127. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
  128. *dst = frand_uniform(rnd);
  129. }
  130. }
  131. }
  132. break;
  133. case 4:
  134. for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
  135. for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
  136. for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
  137. for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
  138. float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
  139. *dst = frand_uniform(rnd);
  140. }
  141. }
  142. }
  143. }
  144. break;
  145. default:
  146. die("Unsupported tensor->n_dims");
  147. };
  148. return tensor;
  149. }
  150. float frand() {
  151. return (float)rand()/((float)(RAND_MAX) + 1.0f);
  152. }
  153. float frand_normal(struct random_normal_distribution * rnd) {
  154. return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
  155. }
  156. float frand_uniform(struct random_uniform_distribution * rnd) {
  157. return rnd->rd(rnd->gen);
  158. }
  159. int clamp(const int v, const int min, const int max) {
  160. return ((v < min) ? (min) : (v > max) ? (max) : v);
  161. }
  162. float fclamp(const float v, const float min, const float max) {
  163. return ((v < min) ? (min) : (v > max) ? (max) : v);
  164. }
// Abort (via GGML_ASSERT) unless `tensor` is 1-dimensional with extent ne0.
void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
    GGML_ASSERT(tensor->n_dims == 1);
    GGML_ASSERT(tensor->ne[0] == ne0);
}
// Abort (via GGML_ASSERT) unless `tensor` is 2-dimensional with extents ne0 x ne1.
void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
    GGML_ASSERT(tensor->n_dims == 2);
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
}
// Abort (via GGML_ASSERT) unless `tensor` is 3-dimensional with extents ne0 x ne1 x ne2.
void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
    GGML_ASSERT(tensor->n_dims == 3);
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
    GGML_ASSERT(tensor->ne[2] == ne2);
}
// Abort (via GGML_ASSERT) unless `tensor` is 4-dimensional with extents ne0 x ne1 x ne2 x ne3.
void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
    GGML_ASSERT(tensor->n_dims == 4);
    GGML_ASSERT(tensor->ne[0] == ne0);
    GGML_ASSERT(tensor->ne[1] == ne1);
    GGML_ASSERT(tensor->ne[2] == ne2);
    GGML_ASSERT(tensor->ne[3] == ne3);
}
// Fill one training batch: writes n_batch token sequences into tokens_input
// (each starting with BOS) and one-hot target probabilities into target_probs.
// Consumes training samples starting at example_id (modulo samples_count);
// returns the number of samples consumed so the caller can advance its cursor.
// tokens_input is [n_tokens, n_batch]; target_probs is [n_vocab, n_tokens, n_batch].
int64_t get_example_targets_batch(
    struct llama_context * lctx,
    struct ggml_tensor   * tokens_input,
    struct ggml_tensor   * target_probs,
    int64_t                example_id,
    const size_t         * samples_offs,
    const size_t         * samples_begin,
    const size_t         * samples_size,
    size_t                 samples_count,
    const llama_token    * train_data,
    size_t                 n_train_data,
    bool                   separate_with_eos,
    bool                   separate_with_bos,
    bool                   fill_with_next_samples,
    bool                   sample_random_offsets
) {
    GGML_ASSERT(samples_count > 0);
    GGML_ASSERT(tokens_input->n_dims == 2);
    GGML_ASSERT(target_probs->n_dims == 3);
    int64_t n_vocab  = target_probs->ne[0];
    int64_t n_tokens = tokens_input->ne[0];
    int64_t n_batch  = tokens_input->ne[1];
    GGML_ASSERT(n_vocab  == target_probs->ne[0]);
    GGML_ASSERT(n_tokens == target_probs->ne[1]);
    GGML_ASSERT(n_batch  == target_probs->ne[2]);
    int64_t used_samples = 0;
    // start from all-zero probabilities; the one-hot entries are set below
    ggml_set_f32(target_probs, 0.0f);
    llama_token bos = llama_token_bos(llama_get_model(lctx));
    llama_token eos = llama_token_eos(llama_get_model(lctx));
    // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
    for (int k=0; k<n_batch; ++k) {
        // printf("%s: batch %d\n", __func__, k);
        size_t sample_idx   = (example_id + used_samples) % samples_count;
        size_t sample_offs  = sample_random_offsets ? samples_offs[sample_idx] : 0;
        size_t sample_begin = samples_begin[sample_idx];
        size_t sample_size  = samples_size[sample_idx];
        ++used_samples;
        // printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
        GGML_ASSERT(sample_begin+sample_size-1 < n_train_data);
        // every sequence starts with BOS
        ggml_set_i32_nd(tokens_input, 0, k, 0, 0, bos);
        // separation flags start "already done" when the corresponding
        // separator token is disabled
        bool sample_separation_eos = !separate_with_eos;
        bool sample_separation_bos = !separate_with_bos;
        for (int64_t i=0; i<n_tokens; ++i) {
            llama_token token = eos; // default when the sample is exhausted and not refilling
            if (sample_offs >= sample_size && fill_with_next_samples) {
                // current sample is exhausted: optionally emit EOS and/or BOS
                // separators (one per position), then advance to the next sample
                if (!sample_separation_eos) {
                    // insert eos token to separate samples
                    sample_separation_eos = true;
                } else if (!sample_separation_bos) {
                    // insert bos token to separate samples
                    sample_separation_bos = true;
                    token = bos;
                } else {
                    // sample separation is done, continue with next sample
                    sample_separation_eos = !separate_with_eos;
                    sample_separation_bos = !separate_with_bos;
                    sample_offs  = 0;
                    sample_idx   = (example_id + used_samples) % samples_count;
                    sample_begin = samples_begin[sample_idx];
                    sample_size  = samples_size[sample_idx];
                    ++used_samples;
                }
            }
            // note: no else-if here
            if (sample_offs < sample_size) {
                // clamp into vocab range to guard against corrupt training data
                token = clamp(train_data[sample_begin+sample_offs], 0, (llama_token) (n_vocab - 1));
                ++sample_offs;
            }
            // one-hot target for position i; the same token becomes the next
            // input (teacher forcing)
            ggml_set_f32_nd(target_probs, token, (int) i, (int) k, 0, +1.0f);
            if (i+1<n_tokens) {
                ggml_set_i32_nd(tokens_input, (int) (i + 1), (int) k, 0, 0, token);
            }
        }
    }
    return used_samples;
}
  263. void mt19937_set_state(std::mt19937& rng, const std::string& rng_state) {
  264. std::stringstream s_rng_state;
  265. s_rng_state.imbue(std::locale::classic());
  266. s_rng_state.exceptions(std::stringstream::failbit);
  267. s_rng_state.str(rng_state);
  268. s_rng_state >> rng;
  269. }
  270. std::string mt19937_get_state(const std::mt19937& rng) {
  271. std::stringstream s_rng_state;
  272. s_rng_state.imbue(std::locale::classic());
  273. s_rng_state << rng;
  274. return s_rng_state.str();
  275. }
  276. std::string mt19937_seed_to_state(unsigned seed) {
  277. std::mt19937 rng(seed);
  278. return mt19937_get_state(rng);
  279. }
  280. std::string shuffle_samples(
  281. const std::string & rng_state,
  282. size_t * shuffled_offs,
  283. size_t * shuffled_begins,
  284. size_t * shuffled_sizes,
  285. const size_t * begins,
  286. const size_t * sizes,
  287. size_t count) {
  288. if (count == 0) return rng_state;
  289. std::mt19937 rng;
  290. mt19937_set_state(rng, rng_state);
  291. // sort indices by random value for each index
  292. std::vector<size_t> idcs;
  293. {
  294. std::vector<unsigned> rnd;
  295. idcs.resize(count);
  296. rnd.resize(count);
  297. for (unsigned i=0; i<count; ++i) {
  298. idcs[i] = i;
  299. rnd[i] = rng();
  300. }
  301. std::sort(idcs.begin(), idcs.end(), [&rnd](size_t a, size_t b){
  302. // stable sort for reproducibility
  303. return (rnd[a] == rnd[b]) ? (a < b) : (rnd[a] < rnd[b]);
  304. });
  305. }
  306. // create random offsets
  307. for (unsigned i=0; i<count; ++i) {
  308. shuffled_offs[i] = (size_t) ((sizes[idcs[i]] - 1) * ((double) rng() / (double) (rng.max()-1)));
  309. }
  310. // reorder begins and sizes by sorted indices
  311. for (unsigned i=0; i<count; ++i) {
  312. shuffled_begins[i] = begins[idcs[i]];
  313. }
  314. for (unsigned i=0; i<count; ++i) {
  315. shuffled_sizes[i] = sizes[idcs[i]];
  316. }
  317. return mt19937_get_state(rng);
  318. }
  319. size_t hash_combine(size_t h1, size_t h2) {
  320. return h1 ^ (h2 << 1);
  321. }
  322. size_t compute_samples_hash(const char* fn, const size_t* samples_begin, const size_t* samples_size, size_t sample_count) {
  323. std::hash<std::string> h_string;
  324. std::hash<unsigned long long> h_ull;
  325. size_t h = h_string(std::string(fn));
  326. h = hash_combine(h, h_ull((unsigned long long) sample_count));
  327. for (size_t i=0; i< sample_count; ++i) {
  328. h = hash_combine(h, h_ull((unsigned long long) samples_begin[i]));
  329. h = hash_combine(h, h_ull((unsigned long long) samples_size[i]));
  330. }
  331. return h;
  332. }
  333. std::string replace_str(const char * s, const char * needle, const char * replacement) {
  334. std::string str = s;
  335. size_t pos = str.find(needle);
  336. if (pos != std::string::npos) {
  337. str.replace(pos, strlen(needle), replacement);
  338. }
  339. return str;
  340. }
  341. void print_duration(double fmillis) {
  342. if (fmillis < 1000.0f) {
  343. printf("%.1fms", (float) fmillis);
  344. return;
  345. }
  346. const int64_t one_sec = 1000;
  347. const int64_t one_min = one_sec * 60;
  348. const int64_t one_hour = one_min * 60;
  349. const int64_t one_day = one_hour * 24;
  350. int64_t millis = (int64_t) fmillis;
  351. int64_t days = millis/one_day;
  352. int64_t hours = (millis - days*one_day)/one_hour;
  353. int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min;
  354. int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec;
  355. // to print int64_t either cast to (long long int) or use macro PRId64 from <inttypes.h>
  356. if (days > 0) {
  357. printf("%lldd ", (long long int) days);
  358. }
  359. printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds);
  360. }
  361. float cosine_decay(int64_t step, int64_t decay_steps, float minimum) {
  362. if (step > decay_steps) {
  363. step = decay_steps;
  364. }
  365. const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
  366. const float decay = (1 - minimum)*cosine_decay + minimum;
  367. return decay;
  368. }
  369. float cosine_decay_restart(int64_t step, int64_t decay_steps, float minimum, float restart_step_mult) {
  370. while (step > decay_steps) {
  371. step -= decay_steps;
  372. decay_steps = (int64_t) (restart_step_mult * decay_steps);
  373. }
  374. return cosine_decay(step, decay_steps, minimum);
  375. }
  376. float learning_schedule(
  377. int64_t step,
  378. int64_t warmup_steps,
  379. int64_t cos_decay_steps,
  380. float learning_rate,
  381. float overall_minimum,
  382. float cos_decay_minimum,
  383. float cos_decay_restart_step_mult,
  384. bool enable_restart) {
  385. float result =
  386. (step < warmup_steps)
  387. ? (float) step / (float) warmup_steps
  388. : enable_restart
  389. ? cosine_decay_restart(
  390. step - warmup_steps,
  391. cos_decay_steps,
  392. cos_decay_minimum,
  393. cos_decay_restart_step_mult)
  394. : cosine_decay(
  395. step,
  396. cos_decay_steps,
  397. cos_decay_minimum);
  398. float min = overall_minimum / learning_rate;
  399. result = min + result * (1.0f - min);
  400. return result;
  401. }
// Assert that two tensors share type and shape and are both contiguous.
// Always returns true; any mismatch aborts via GGML_ASSERT.
static bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
    GGML_ASSERT(a != NULL);
    GGML_ASSERT(b != NULL);
    GGML_ASSERT(a->type == b->type);
    GGML_ASSERT(ggml_are_same_shape(a, b));
    GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
    return true;
}
// Copy the data of the tensor named `name` in `ctx` into `dst` (no-op when
// dst is NULL). Layouts must match (checked by are_same_layout). If dst is
// unnamed, it inherits `name`.
// NOTE(review): ggml_get_tensor returning NULL is caught only indirectly by
// the layout assertion dereferencing t.
void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
    if (dst == NULL) {
        return;
    }
    struct ggml_tensor * t = ggml_get_tensor(ctx, name);
    GGML_ASSERT(are_same_layout(dst, t));
    memcpy(dst->data, t->data, ggml_nbytes(t));
    if (strlen(ggml_get_name(dst)) == 0) {
        ggml_set_name(dst, name);
    }
}
// gguf constants: KV keys and tensor names used when (de)serializing
// optimizer and training state to/from gguf checkpoint files.

// optimizer type discriminator and its values
static const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
static const char * LLM_KV_OPTIMIZER_TYPE_ADAM  = "adam";
static const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";

// optimizer scalar state (shared)
static const char * LLM_KV_OPTIMIZER_FILE_VERSION               = "optimizer.file_version";
static const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT     = "optimizer.convergence_past_count";
static const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT            = "optimizer.parameter_count";
static const char * LLM_KV_OPTIMIZER_ITERATION_COUNT            = "optimizer.iteration_count";
static const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED           = "optimizer.just_initialized";

// ADAM-specific scalar state
static const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS             = "optimizer.adam.best_loss";
static const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS         = "optimizer.adam.previous_loss";
static const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT  = "optimizer.adam.no_improvement_count";

// LBFGS-specific scalar state
static const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
static const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS            = "optimizer.lbfgs.best_loss";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP     = "optimizer.lbfgs.line_search_step";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J        = "optimizer.lbfgs.line_search_j";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K        = "optimizer.lbfgs.line_search_k";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END      = "optimizer.lbfgs.line_search_end";
static const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";

// ADAM tensor state
static const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS     = "optimizer.adam.first_moments";
static const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS    = "optimizer.adam.second_moments";
static const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES  = "optimizer.adam.past_loss_values";

// LBFGS tensor state
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS  = "optimizer.lbfgs.current_parameters";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS   = "optimizer.lbfgs.current_gradients";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS  = "optimizer.lbfgs.previous_gradients";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION    = "optimizer.lbfgs.search_direction";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES    = "optimizer.lbfgs.past_loss_values";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA        = "optimizer.lbfgs.memory_alpha";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS           = "optimizer.lbfgs.memory_ys";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S            = "optimizer.lbfgs.memory_s";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y            = "optimizer.lbfgs.memory_y";

// training progress and shuffle state
static const char * LLM_KV_TRAINING_FILE_VERSION         = "training.file_version";
static const char * LLM_KV_TRAINING_ITERATION_COUNT      = "training.iteration_count";
static const char * LLM_KV_TRAINING_SAMPLE_COUNT         = "training.sample_count";
static const char * LLM_KV_TRAINING_TOKEN_COUNT          = "training.token_count";
static const char * LLM_KV_TRAINING_EPOCH_COUNT          = "training.epoch_count";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE    = "training.shuffle.rng_state";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE  = "training.shuffle.next_sample";
  462. #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
  463. { \
  464. const std::string skey(key); \
  465. const int kid = gguf_find_key(ctx, skey.c_str()); \
  466. if (kid >= 0) { \
  467. enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
  468. if (ktype != (type)) { \
  469. die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
  470. } \
  471. (dst) = func(ctx, kid); \
  472. } else if (req) { \
  473. die_fmt("key not found in model: %s", skey.c_str()); \
  474. } \
  475. }
// Restore an optimizer context from a gguf checkpoint: scalar state comes from
// KV pairs, tensor state (moments, gradients, ...) from named tensors in
// f_ggml_ctx. Dies on missing/ill-typed required keys or unknown optimizer type.
// NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
    uint32_t file_version;
    GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
    GGML_ASSERT(file_version == 0); // only version 0 of the optimizer section exists

    GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
    GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
    GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);

    uint64_t nx;
    GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
    opt->nx = (size_t) nx;

    // don't call ggml_opt_init until optimizer type and optimizer specific parameters are know
    std::string opt_type;
    GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
    if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
        opt->params.type = GGML_OPT_ADAM;
        // scalar ADAM state first, then allocate and fill the moment tensors
        GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
        GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
        GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
        ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
        copy_tensor_by_name(opt->adam.m,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
        copy_tensor_by_name(opt->adam.v,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
        copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
    } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
        opt->params.type = GGML_OPT_LBFGS;
        // scalar LBFGS state first, then allocate and fill the history tensors
        GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
        GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
        GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
        GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
        GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
        GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
        GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
        ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
        copy_tensor_by_name(opt->lbfgs.x,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
        copy_tensor_by_name(opt->lbfgs.xp,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
        copy_tensor_by_name(opt->lbfgs.g,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
        copy_tensor_by_name(opt->lbfgs.gp,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
        copy_tensor_by_name(opt->lbfgs.d,    f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
        copy_tensor_by_name(opt->lbfgs.pf,   f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
        copy_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
        copy_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
        copy_tensor_by_name(opt->lbfgs.lms,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
        copy_tensor_by_name(opt->lbfgs.lmy,  f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
    } else {
        die("unknown optimizer type\n");
    }
}
// Serialize an optimizer context (ADAM or LBFGS) into a gguf file context:
// scalar state as KV pairs, tensor state as named tensors. Counterpart of
// load_opt_context_gguf; always writes optimizer file version 0.
// NOTE(review): the switch has no default case — an unexpected
// opt->params.type silently writes only the common KV pairs.
void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
    gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
    gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
    gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
    gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
    gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
    switch (opt->params.type) {
        case GGML_OPT_ADAM:
            {
                gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev);
                gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
                // name tensors first so gguf_add_tensor records them correctly
                ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
                ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
                if (opt->adam.pf) {
                    // past-loss buffer exists only when params.past > 0
                    ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
                }
                gguf_add_tensor(fctx, opt->adam.m);
                gguf_add_tensor(fctx, opt->adam.v);
                if (opt->adam.pf) {
                    gguf_add_tensor(fctx, opt->adam.pf);
                }
            } break;
        case GGML_OPT_LBFGS:
            {
                gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
                gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best);
                gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step);
                gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j);
                gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k);
                gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end);
                gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
                // name tensors first so gguf_add_tensor records them correctly
                ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
                ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
                ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
                ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
                ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
                if (opt->lbfgs.pf) {
                    // past-loss buffer exists only when params.past > 0
                    ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
                }
                ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
                ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
                ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
                ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
                gguf_add_tensor(fctx, opt->lbfgs.x);
                gguf_add_tensor(fctx, opt->lbfgs.xp);
                gguf_add_tensor(fctx, opt->lbfgs.g);
                gguf_add_tensor(fctx, opt->lbfgs.gp);
                gguf_add_tensor(fctx, opt->lbfgs.d);
                if (opt->lbfgs.pf) {
                    gguf_add_tensor(fctx, opt->lbfgs.pf);
                }
                gguf_add_tensor(fctx, opt->lbfgs.lmal);
                gguf_add_tensor(fctx, opt->lbfgs.lmys);
                gguf_add_tensor(fctx, opt->lbfgs.lms);
                gguf_add_tensor(fctx, opt->lbfgs.lmy);
            } break;
    }
}
// Load training state (iteration/sample/token counters, shuffling state and
// the optimizer context) from a GGUF file into `train`.
// Returns false when the file contains no training section at all
// (LLM_KV_TRAINING_FILE_VERSION key missing), true otherwise.
// Asserts when the stored file version is newer than 1.
bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) {
    if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) < 0) {
        return false;
    }

    uint32_t file_version;
    GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
    GGML_ASSERT(file_version <= 1);

    if (file_version == 0) {
        // file version 0 stored the counters as 32 bit values
        GGUF_GET_KEY(fctx, train->train_its,     gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
        GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
        GGUF_GET_KEY(fctx, train->train_tokens,  gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
    } else if (file_version == 1) {
        // file version 1 widened the counters to 64 bit and added the epoch
        // count plus the sample shuffling state; the shuffle keys are optional
        // (last argument false) so older checkpoints still load
        GGUF_GET_KEY(fctx, train->train_its,     gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT);
        GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT);
        GGUF_GET_KEY(fctx, train->train_tokens,  gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT);
        GGUF_GET_KEY(fctx, train->train_epochs,  gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT);
        GGUF_GET_KEY(fctx, train->shuffle_samples_hash,      gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH);
        GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE);
        GGUF_GET_KEY(fctx, train->shuffle_sample_count,      gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT);
        GGUF_GET_KEY(fctx, train->shuffle_next_sample,       gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE);
    }

    load_opt_context_gguf(fctx, f_ggml_ctx, train->opt);
    return true;
}
// Save the training state (counters, shuffling state and optimizer context)
// into a GGUF context, always using the current file version 1 layout.
// Counterpart of load_train_state_gguf.
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
    gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT,    train->train_samples);
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT,     train->train_tokens);
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT,     train->train_epochs);
    // sample shuffling state, needed to resume a run deterministically
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash);
    gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE,    train->shuffle_rng_state_current.c_str());
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count);
    gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE,  (uint64_t) train->shuffle_next_sample);

    save_opt_context_gguf(fctx, train->opt);
}
  620. struct llama_file {
  621. // use FILE * so we don't have to re-open the file to mmap
  622. FILE * fp;
  623. size_t size;
  624. llama_file(const char * fname, const char * mode) {
  625. fp = std::fopen(fname, mode);
  626. if (fp == NULL) {
  627. size = 0;
  628. } else {
  629. seek(0, SEEK_END);
  630. size = tell();
  631. seek(0, SEEK_SET);
  632. }
  633. }
  634. size_t tell() const {
  635. #ifdef _WIN32
  636. __int64 ret = _ftelli64(fp);
  637. #else
  638. long ret = std::ftell(fp);
  639. #endif
  640. GGML_ASSERT(ret != -1); // this really shouldn't fail
  641. return (size_t) ret;
  642. }
  643. void seek(size_t offset, int whence) {
  644. #ifdef _WIN32
  645. int ret = _fseeki64(fp, (__int64) offset, whence);
  646. #else
  647. int ret = std::fseek(fp, (long) offset, whence);
  648. #endif
  649. GGML_ASSERT(ret == 0); // same
  650. }
  651. void read_raw(void * ptr, size_t size) {
  652. if (size == 0) {
  653. return;
  654. }
  655. errno = 0;
  656. std::size_t ret = std::fread(ptr, size, 1, fp);
  657. if (ferror(fp)) {
  658. die_fmt("read error: %s", strerror(errno));
  659. }
  660. if (ret != 1) {
  661. die("unexpectedly reached end of file");
  662. }
  663. }
  664. std::uint32_t read_u32() {
  665. std::uint32_t ret;
  666. read_raw(&ret, sizeof(ret));
  667. return ret;
  668. }
  669. std::string read_string(std::uint32_t len) {
  670. std::vector<char> chars(len);
  671. read_raw(chars.data(), len);
  672. return std::string(chars.data(), len);
  673. }
  674. void write_raw(const void * ptr, size_t size) {
  675. if (size == 0) {
  676. return;
  677. }
  678. errno = 0;
  679. size_t ret = std::fwrite(ptr, size, 1, fp);
  680. if (ret != 1) {
  681. die_fmt("write error: %s", strerror(errno));
  682. }
  683. }
  684. void write_u32(std::uint32_t val) {
  685. write_raw(&val, sizeof(val));
  686. }
  687. ~llama_file() {
  688. if (fp) {
  689. std::fclose(fp);
  690. }
  691. }
  692. };
  693. static size_t utf8_len(char src) {
  694. const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
  695. uint8_t highbits = static_cast<uint8_t>(src) >> 4;
  696. return lookup[highbits];
  697. }
  698. // mark each byte with its utf8 unit number.
  699. // returns the number of utf8 characters.
  700. // e.g. when bytes == '\x61\xD0\xB0\x62',
  701. // then utf8_units will become [0,0,1,0]
  702. // utf8_nunits will become [1,2,2,1] and 3 is returned.
  703. // bytes where utf8_units is zero, are the begin of an utf8 character.
  704. static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) {
  705. size_t offs = 0;
  706. size_t count_utf8 = 0;
  707. while(offs < count) {
  708. int len = (int) utf8_len(bytes[offs]);
  709. for (int i=0; i<len; ++i) {
  710. utf8_units[offs+i] = i;
  711. utf8_nunits[offs+i] = len;
  712. }
  713. offs += len;
  714. ++count_utf8;
  715. }
  716. return count_utf8;
  717. }
// Read a training data file and tokenize it into out_tokens, producing the
// list of sample start positions (indices into out_tokens) and sample sizes
// (token counts) that the training loop draws batches from.
//
// Two modes:
//  - sample_start == "": the whole file is tokenized at once and every token
//    position (up to size - context_length) becomes a sample start of length
//    context_length.
//  - otherwise the data is split at each occurrence of the pattern
//    sample_start and each piece is tokenized separately.
//
// Returns the total number of tokens in out_tokens.
size_t tokenize_file(
        struct llama_context     * lctx,
        const char               * filename,
        const std::string        & sample_start,
        bool                       include_sample_start,
        bool                       overlapping_samples,
        unsigned                   context_length,
        std::vector<llama_token> & out_tokens,
        std::vector<size_t>      & out_samples_begin,
        std::vector<size_t>      & out_samples_size) {
    struct llama_file f(filename, "rb");
    if (f.size == 0) {
        // missing or empty file: return empty outputs with a warning
        out_tokens.clear();
        out_samples_begin.clear();
        out_samples_size.clear();
        printf("%s: warning: empty or not existing training data file '%s'\n",
            __func__, filename);
        return out_tokens.size();
    }

    // account for possible leading whitespace that will be added by tokenizer
    // e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
    const int n_max_tokens_overhead = 1;

    std::vector<char> buf;
    buf.resize(f.size);
    f.read_raw(buf.data(), f.size);

    // per-byte utf8 bookkeeping, used below so samples are never cut in the
    // middle of a multi-byte utf8 character
    std::vector<int> utf8_units;
    std::vector<int> utf8_nunits;
    utf8_units.resize(buf.size());
    utf8_nunits.resize(buf.size());
    mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());

    if (sample_start.size() == 0) {
        // tokenize all data at once
        out_tokens.resize(buf.size() + n_max_tokens_overhead);

        int n_tokens = llama_tokenize(
            llama_get_model(lctx),
            buf.data(),
            (int) buf.size(),
            out_tokens.data(),
            (int) out_tokens.size(),
            false, false);
        if (n_tokens < 0) {
            // llama_tokenize returns the negated required token count when the
            // output buffer is too small; resize and tokenize again
            out_tokens.resize(-n_tokens);
            n_tokens = llama_tokenize(
                llama_get_model(lctx),
                buf.data(),
                (int) buf.size(),
                out_tokens.data(),
                (int) out_tokens.size(),
                false, false);
        }
        if (n_tokens >= 0) {
            out_tokens.resize(n_tokens);
        }

        // generate sample starts at all token positions
        out_samples_begin.clear();
        out_samples_begin.push_back(0);
        out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size()));
        size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0;
        for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) {
            out_samples_begin.push_back(sample_begin);
            out_samples_size.push_back(context_length);
        }
    } else {
        // split data into samples and tokenize each sample
        std::string data_str(buf.data(), buf.size());
        out_samples_begin.clear();
        out_samples_size.clear();
        out_tokens.clear();

        // find all positions of pattern sample_start
        size_t sample_begin = data_str.find(sample_start, 0);
        while (sample_begin != std::string::npos) {
            out_samples_begin.push_back(sample_begin);
            const size_t search_start = sample_begin + sample_start.size();
            sample_begin = data_str.find(sample_start, search_start);
        }
        if (out_samples_begin.size() == 0) {
            printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n",
                __func__, sample_start.c_str());
            out_samples_begin.push_back(0);
        }

        out_samples_size.resize(out_samples_begin.size(), 0);

        // scratch buffers reused across samples to avoid reallocation
        std::vector<char>        buf_sample;
        std::vector<llama_token> tok_sample;

        const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size());
        size_t found_too_big_sample   = 0;
        size_t found_too_small_sample = 0;
        size_t found_empty_sample     = 0;
        size_t found_min_sample_size  = SIZE_MAX;
        size_t found_max_sample_size  = 0;

        // longest token text over the whole vocabulary
        size_t max_token_text_size = 0;
        int n_vocab = llama_n_vocab(llama_get_model(lctx));
        for (llama_token token=0; token < n_vocab; ++token) {
            max_token_text_size = std::max(
                max_token_text_size,
                strlen(llama_token_get_text(llama_get_model(lctx), token)));
        }

        // upper bound of context byte length.
        // strings with this byte length should always tokenize to at least context_length tokens.
        size_t context_byte_len = max_token_text_size*context_length;

        for (unsigned i=0; i<out_samples_begin.size(); ++i) {
            // determine sample begin and end from pattern positions
            // NOTE(review): this sample_begin intentionally shadows the
            // pattern-search variable of the same name above
            size_t sample_begin = out_samples_begin[i] + sample_begin_offset;
            size_t sample_end   = overlapping_samples
                ? std::min(
                    data_str.size(),
                    sample_begin + context_byte_len)
                : (i+1 < out_samples_begin.size()
                    ? out_samples_begin[i+1]
                    : data_str.size());
            if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
                // sample end is in the middle of an utf8 character.
                // advance sample_end to the begin of the next utf8 character.
                sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
            }
            size_t sample_size = sample_end - sample_begin;
            if (sample_size == 0) {
                ++found_empty_sample;
            }

            if (sample_size > 0) {
                // llama_tokenize expects zero terminated string,
                // copy sample into buffer and zero terminate it.
                // NOTE(review): the buffer is not actually NUL-terminated here;
                // llama_tokenize is given the explicit length instead — confirm
                // the comment above is stale rather than the code.
                buf_sample.resize(sample_size);
                memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size);

                // printf("sample: '%s'\n", buf_sample.data());

                // tokenize the sample
                tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
                int n_tokens = llama_tokenize(llama_get_model(lctx),
                    buf_sample.data(),
                    (int) buf_sample.size(),
                    tok_sample.data(),
                    (int) tok_sample.size(),
                    false, false);
                if (n_tokens < 0) {
                    // negated value is the required buffer size; retry once
                    tok_sample.resize(-n_tokens);
                    n_tokens = llama_tokenize(llama_get_model(lctx),
                        buf_sample.data(),
                        (int) buf_sample.size(),
                        tok_sample.data(),
                        (int) tok_sample.size(),
                        false, false);
                    GGML_ASSERT(n_tokens >= 0);
                }
                GGML_ASSERT(n_tokens <= (int) tok_sample.size());

                // track statistics for the warnings printed below
                if ((size_t) n_tokens > context_length) {
                    ++found_too_big_sample;
                } else if ((size_t) n_tokens < context_length) {
                    ++found_too_small_sample;
                }
                found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens);
                found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens);

                // write out tokens, start and size of sample
                // overwrite the string start position with the token start position
                out_samples_begin[i] = out_tokens.size();
                out_samples_size[i]  = (size_t) n_tokens;
                out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens);
            } else {
                out_samples_begin[i] = out_tokens.size();
                out_samples_size[i]  = 0;
            }
        }

        if (found_too_big_sample > 0) {
            printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n",
                __func__, found_too_big_sample, found_max_sample_size, context_length);
        }
        if (found_too_small_sample > 0) {
            printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n",
                __func__, found_too_small_sample, found_min_sample_size, context_length);
        }
        if (found_empty_sample) {
            printf("%s: warning: found %zu empty samples.\n",
                __func__, found_empty_sample);
        }
    }
    printf("%s: total number of samples: %zu\n",
        __func__, out_samples_begin.size());

    GGML_ASSERT(out_samples_begin.size() == out_samples_size.size());

    return out_tokens.size();
}
  896. std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) {
  897. std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
  898. return replace_str(filename, pattern_it, sit.c_str());
  899. }
// Default values for the training parameters shared by all training examples.
struct train_params_common get_default_train_params_common() {
    struct train_params_common params;
    // input/output files and checkpoint naming
    params.fn_train_data     = "shakespeare.txt";
    params.fn_checkpoint_in  = "checkpoint.gguf";
    params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
    params.pattern_fn_it     = "ITERATION"; // placeholder replaced by the iteration number
    params.fn_latest         = "LATEST";    // placeholder value used for the most recent save
    params.print_usage       = false;
    params.save_every        = 10; // iterations between checkpoint saves; <= 0 disables saving

    // core training setup
    params.seed                    = -1; // -1 means: pick a random seed
    params.n_ctx                   = 128;
    params.n_threads               = 6;
    params.n_batch                 = 8;
    params.n_gradient_accumulation = 1;
    params.n_epochs                = -1; // -1 means: no epoch limit
    params.custom_n_ctx            = false; // set to true when the user passes -c/--ctx

    params.use_flash         = true;
    params.use_checkpointing = true;

    // sampling of training data
    params.sample_start          = ""; // empty: every token position is a sample start
    params.include_sample_start  = false;
    params.escape                = false;
    params.overlapping_samples   = false;
    params.fill_with_next_samples = false;
    params.separate_with_eos     = false;
    params.separate_with_bos     = true;
    params.sample_random_offsets = false;
    params.force_reshuffle       = false;

    // convergence tests
    params.opt_past               = 0;     // 0 disables the delta convergence test
    params.opt_delta              = 1e-5f;
    params.opt_max_no_improvement = 0;     // 0 disables the no-improvement test

    // learning rate schedule (warmup + cosine decay)
    params.warmup            = 100;
    params.cos_decay_steps   = 1000;
    params.cos_decay_restart = 1.1f;
    params.cos_decay_min     = 0.1f;
    params.enable_restart    = false;

    // Adam/AdamW optimizer settings
    params.adam_n_iter         = 256;
    params.adam_alpha          = 1e-3f;
    params.adam_min_alpha      = 0;
    params.adam_decay          = 1e-1f; // > 0 enables AdamW weight decay
    params.adam_decay_min_ndim = 2;
    params.adam_beta1          = 0.9f;
    params.adam_beta2          = 0.999f;
    params.adam_gclip          = 1.0f;  // 0 disables gradient clipping
    params.adam_eps_f          = 0.0f;  // <= 0 disables the epsilon convergence test
    return params;
}
// Print the help text for the common training options, with each option's
// current default taken from `params`. The commented-out lines below suggest
// the generic usage header is printed by the caller — NOTE(review): confirm
// against the callers of this function.
void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) {
    // fprintf(stderr, "usage: %s [options]\n", argv[0]);
    // fprintf(stderr, "\n");
    // fprintf(stderr, "options:\n");
    // fprintf(stderr, "  -h, --help                 show this help message and exit\n");
    fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
    fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
    fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
    fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
    fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
    fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
    fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
    fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
    fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
    fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
    fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
    fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
    fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n");
    fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
    fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
    fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
    fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
    fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
    fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
    fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
    fprintf(stderr, " --sample-random-offsets Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : "");
    fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
    fprintf(stderr, " --no-flash Don't use flash attention \n");
    fprintf(stderr, " --use-flash Use flash attention (default)\n");
    fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
    fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
    fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
    fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
    fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
    fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
    fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
    fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
    fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
    fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
    fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
    fprintf(stderr, " --epochs N Maximum number epochs to process. (default %d)\n", params->n_epochs);
    fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
    fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
    fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
    fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
    fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
    fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
    fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
    fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
    fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
    fprintf(stderr, "\n");
}
  998. bool consume_common_train_arg(
  999. int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param
  1000. ) {
  1001. int& i = *idx;
  1002. std::string arg = argv[i];
  1003. const std::string arg_prefix = "--";
  1004. if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
  1005. std::replace(arg.begin(), arg.end(), '_', '-');
  1006. }
  1007. if (arg == "--train-data") {
  1008. if (++i >= argc) {
  1009. *invalid_param = true;
  1010. return true;
  1011. }
  1012. params->fn_train_data = argv[i];
  1013. } else if (arg == "--checkpoint-in") {
  1014. if (++i >= argc) {
  1015. *invalid_param = true;
  1016. return true;
  1017. }
  1018. params->fn_checkpoint_in = argv[i];
  1019. } else if (arg == "--checkpoint-out") {
  1020. if (++i >= argc) {
  1021. *invalid_param = true;
  1022. return true;
  1023. }
  1024. params->fn_checkpoint_out = argv[i];
  1025. } else if (arg == "--pattern-fn-it") {
  1026. if (++i >= argc) {
  1027. *invalid_param = true;
  1028. return true;
  1029. }
  1030. params->pattern_fn_it = argv[i];
  1031. } else if (arg == "--fn-latest") {
  1032. if (++i >= argc) {
  1033. *invalid_param = true;
  1034. return true;
  1035. }
  1036. params->fn_latest = argv[i];
  1037. } else if (arg == "--save-every") {
  1038. if (++i >= argc) {
  1039. *invalid_param = true;
  1040. return true;
  1041. }
  1042. params->save_every = std::stoi(argv[i]);
  1043. } else if (arg == "-s" || arg == "--seed") {
  1044. if (++i >= argc) {
  1045. *invalid_param = true;
  1046. return true;
  1047. }
  1048. params->seed = std::stoi(argv[i]);
  1049. } else if (arg == "-c" || arg == "--ctx") {
  1050. if (++i >= argc) {
  1051. *invalid_param = true;
  1052. return true;
  1053. }
  1054. params->n_ctx = std::stoi(argv[i]);
  1055. params->custom_n_ctx = true;
  1056. } else if (arg == "-t" || arg == "--threads") {
  1057. if (++i >= argc) {
  1058. *invalid_param = true;
  1059. return true;
  1060. }
  1061. params->n_threads = std::stoi(argv[i]);
  1062. } else if (arg == "-b" || arg == "--batch") {
  1063. if (++i >= argc) {
  1064. *invalid_param = true;
  1065. return true;
  1066. }
  1067. params->n_batch = std::stoi(argv[i]);
  1068. } else if (arg == "--grad-acc") {
  1069. if (++i >= argc) {
  1070. *invalid_param = true;
  1071. return true;
  1072. }
  1073. params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
  1074. } else if (arg == "--sample-start") {
  1075. if (++i >= argc) {
  1076. *invalid_param = true;
  1077. return true;
  1078. }
  1079. params->sample_start = std::string(argv[i]);
  1080. } else if (arg == "--escape") {
  1081. params->escape = true;
  1082. } else if (arg == "--include-sample-start") {
  1083. params->include_sample_start = true;
  1084. } else if (arg == "--overlapping-samples") {
  1085. params->overlapping_samples = true;
  1086. } else if (arg == "--fill-with-next-samples") {
  1087. params->fill_with_next_samples = true;
  1088. } else if (arg == "--separate-with-eos") {
  1089. params->separate_with_eos = true;
  1090. } else if (arg == "--separate-with-bos") {
  1091. params->separate_with_bos = true;
  1092. } else if (arg == "--no-separate-with-eos") {
  1093. params->separate_with_eos = false;
  1094. } else if (arg == "--no-separate-with-bos") {
  1095. params->separate_with_bos = false;
  1096. } else if (arg == "--sample-random-offsets") {
  1097. params->sample_random_offsets = true;
  1098. } else if (arg == "--force-reshuffle") {
  1099. params->force_reshuffle = true;
  1100. } else if (arg == "--no-flash") {
  1101. params->use_flash = false;
  1102. } else if (arg == "--use-flash") {
  1103. params->use_flash = true;
  1104. } else if (arg == "--no-checkpointing") {
  1105. params->use_checkpointing = false;
  1106. } else if (arg == "--use-checkpointing") {
  1107. params->use_checkpointing = true;
  1108. } else if (arg == "--warmup") {
  1109. if (++i >= argc) {
  1110. *invalid_param = true;
  1111. return true;
  1112. }
  1113. params->warmup = std::stoi(argv[i]);
  1114. } else if (arg == "--cos-decay-steps") {
  1115. if (++i >= argc) {
  1116. *invalid_param = true;
  1117. return true;
  1118. }
  1119. params->cos_decay_steps = std::stoi(argv[i]);
  1120. } else if (arg == "--cos-decay-restart") {
  1121. if (++i >= argc) {
  1122. *invalid_param = true;
  1123. return true;
  1124. }
  1125. params->cos_decay_restart = std::stof(argv[i]);
  1126. } else if (arg == "--cos-decay-min") {
  1127. if (++i >= argc) {
  1128. *invalid_param = true;
  1129. return true;
  1130. }
  1131. params->cos_decay_min = std::stof(argv[i]);
  1132. } else if (arg == "--enable-restart") {
  1133. params->enable_restart = true;
  1134. } else if (arg == "--disable-restart") {
  1135. params->enable_restart = false;
  1136. } else if (arg == "--opt-past") {
  1137. if (++i >= argc) {
  1138. *invalid_param = true;
  1139. return true;
  1140. }
  1141. params->opt_past = std::stoi(argv[i]);
  1142. } else if (arg == "--opt-delta") {
  1143. if (++i >= argc) {
  1144. *invalid_param = true;
  1145. return true;
  1146. }
  1147. params->opt_delta = std::stof(argv[i]);
  1148. } else if (arg == "--opt-max-no-improvement") {
  1149. if (++i >= argc) {
  1150. *invalid_param = true;
  1151. return true;
  1152. }
  1153. params->opt_max_no_improvement = std::stoi(argv[i]);
  1154. } else if (arg == "--adam-epsf") {
  1155. if (++i >= argc) {
  1156. *invalid_param = true;
  1157. return true;
  1158. }
  1159. params->adam_eps_f = std::stof(argv[i]);
  1160. } else if (arg == "--epochs") {
  1161. if (++i >= argc) {
  1162. *invalid_param = true;
  1163. return true;
  1164. }
  1165. params->n_epochs = std::stoi(argv[i]);
  1166. } else if (arg == "--adam-iter") {
  1167. if (++i >= argc) {
  1168. *invalid_param = true;
  1169. return true;
  1170. }
  1171. params->adam_n_iter = std::stoi(argv[i]);
  1172. } else if (arg == "--adam-alpha") {
  1173. if (++i >= argc) {
  1174. *invalid_param = true;
  1175. return true;
  1176. }
  1177. params->adam_alpha = std::stof(argv[i]);
  1178. } else if (arg == "--adam-min-alpha") {
  1179. if (++i >= argc) {
  1180. *invalid_param = true;
  1181. return true;
  1182. }
  1183. params->adam_min_alpha = std::stof(argv[i]);
  1184. } else if (arg == "--adam-decay") {
  1185. if (++i >= argc) {
  1186. *invalid_param = true;
  1187. return true;
  1188. }
  1189. params->adam_decay = std::stof(argv[i]);
  1190. } else if (arg == "--adam-decay-min-ndim") {
  1191. if (++i >= argc) {
  1192. *invalid_param = true;
  1193. return true;
  1194. }
  1195. params->adam_decay_min_ndim = std::stoi(argv[i]);
  1196. } else if (arg == "--adam-beta1") {
  1197. if (++i >= argc) {
  1198. *invalid_param = true;
  1199. return true;
  1200. }
  1201. params->adam_beta1 = std::stof(argv[i]);
  1202. } else if (arg == "--adam-beta2") {
  1203. if (++i >= argc) {
  1204. *invalid_param = true;
  1205. return true;
  1206. }
  1207. params->adam_beta2 = std::stof(argv[i]);
  1208. } else if (arg == "--adam-gclip") {
  1209. if (++i >= argc) {
  1210. *invalid_param = true;
  1211. return true;
  1212. }
  1213. params->adam_gclip = std::stof(argv[i]);
  1214. } else if (arg == "-h" || arg == "--help") {
  1215. params->print_usage = true;
  1216. return true;
  1217. } else {
  1218. return false;
  1219. }
  1220. return true;
  1221. }
  1222. void finish_processing_train_args(struct train_params_common * params) {
  1223. if (params->escape) {
  1224. process_escapes(params->sample_start);
  1225. }
  1226. }
  1227. void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
  1228. struct train_opt_callback_data * data = (struct train_opt_callback_data *) vdata;
  1229. struct train_params_common * params = data->params;
  1230. struct train_state * train = data->train;
  1231. struct ggml_opt_context * opt = train->opt;
  1232. int n_batch = params->n_batch;
  1233. int n_ctx = params->n_ctx;
  1234. if (accum_step == 0) {
  1235. // time measurement
  1236. int64_t now = ggml_time_ms();
  1237. if (now > data->last_time && opt->iter > data->first_iter) {
  1238. double dt = (double) (now - data->last_time);
  1239. if (data->millis_per_iter == 0.0) {
  1240. data->millis_per_iter = dt;
  1241. } else {
  1242. const double gain = 0.7;
  1243. data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain;
  1244. }
  1245. }
  1246. double remaining_millis = 0.0;
  1247. if (data->millis_per_iter > 0.0) {
  1248. const int n_iter = params->adam_n_iter;
  1249. const int done_iter = opt->iter - data->first_iter;
  1250. const int remaining_iter = n_iter - done_iter;
  1251. remaining_millis = remaining_iter * data->millis_per_iter;
  1252. }
  1253. // file saving
  1254. const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
  1255. if (save_now) {
  1256. int new_iters = opt->iter - data->last_save_iter;
  1257. train->train_its += new_iters;
  1258. train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
  1259. if (data->save_cb) {
  1260. data->save_cb(data->save_data, train);
  1261. }
  1262. data->last_save_iter = opt->iter;
  1263. }
  1264. // exclude file saving from time measurement, by measuring last_time after saving
  1265. data->last_time = ggml_time_ms();
  1266. *sched = learning_schedule(
  1267. opt->iter,
  1268. params->warmup,
  1269. params->cos_decay_steps,
  1270. params->adam_alpha,
  1271. params->adam_min_alpha,
  1272. params->cos_decay_min,
  1273. params->cos_decay_restart,
  1274. params->enable_restart);
  1275. int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
  1276. if (impr_plot > 0) impr_plot = 0;
  1277. if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0;
  1278. printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
  1279. __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
  1280. *sched, opt->loss_after);
  1281. if (data->millis_per_iter > 0) {
  1282. printf(" dt=");
  1283. print_duration(data->millis_per_iter);
  1284. printf(" eta=");
  1285. print_duration(remaining_millis);
  1286. }
  1287. float improvement = opt->loss_before - opt->loss_after;
  1288. const float plot_scale = 10.0f;
  1289. int bar_len = (int)(1 + improvement*plot_scale + 0.5);
  1290. printf(" |");
  1291. for (int i=0; i<bar_len; ++i) {
  1292. printf("-");
  1293. }
  1294. printf(">");
  1295. printf("\n");
  1296. }
  1297. int64_t used_samples = get_example_targets_batch(
  1298. data->lctx,
  1299. data->tokens_input,
  1300. data->target_probs,
  1301. train->shuffle_next_sample,
  1302. data->shuffled_samples_offs,
  1303. data->shuffled_samples_begin,
  1304. data->shuffled_samples_size,
  1305. data->samples_count,
  1306. data->tokens_data,
  1307. data->tokens_size,
  1308. params->separate_with_eos,
  1309. params->separate_with_bos,
  1310. params->fill_with_next_samples,
  1311. params->sample_random_offsets);
  1312. train->train_samples += used_samples;
  1313. train->shuffle_next_sample += used_samples;
  1314. if (train->shuffle_next_sample >= train->shuffle_sample_count) {
  1315. ++train->train_epochs;
  1316. printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
  1317. // note: we may have used some samples from the current shuffling more than once
  1318. train->shuffle_rng_state_current = train->shuffle_rng_state_next;
  1319. train->shuffle_rng_state_next = shuffle_samples(
  1320. train->shuffle_rng_state_current,
  1321. data->shuffled_samples_offs,
  1322. data->shuffled_samples_begin,
  1323. data->shuffled_samples_size,
  1324. data->samples_begin,
  1325. data->samples_size,
  1326. data->samples_count);
  1327. train->shuffle_next_sample = 0;
  1328. }
  1329. const bool last_epoch_reached = (params->n_epochs > 0 && (int64_t) train->train_epochs - data->first_epoch >= params->n_epochs);
  1330. if (last_epoch_reached) {
  1331. // allow optimization iteration at last epoch to be completed before canceling
  1332. if (data->iter_at_last_epoch < 0) {
  1333. data->iter_at_last_epoch = opt->iter;
  1334. } else if (opt->iter > data->iter_at_last_epoch) {
  1335. *cancel = true;
  1336. }
  1337. }
  1338. }