// ggml-opt.cpp

#include "ggml-opt.h"

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-impl.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cinttypes>
#include <cstdio>  // fprintf, fflush
#include <cstring> // memcpy, strcpy
#include <map>
#include <random>
#include <vector>

struct ggml_opt_dataset {
    struct ggml_context   * ctx;
    ggml_backend_buffer_t   buf;
    struct ggml_tensor    * data;
    struct ggml_tensor    * labels;

    int64_t ndata;       // total number of datapoints
    int64_t ndata_shard; // number of datapoints per shard (data is shuffled at shard granularity)
    size_t  nbs_data;    // size in bytes of one shard of data
    size_t  nbs_labels;  // size in bytes of one shard of labels

    std::vector<int64_t> permutation;
};

struct ggml_opt_context {
    ggml_backend_sched_t    backend_sched;
    ggml_cgraph           * allocated_graph;
    ggml_cgraph           * allocated_graph_copy;
    struct ggml_context   * ctx_static;
    struct ggml_context   * ctx_static_cpu;
    struct ggml_context   * ctx_compute;
    struct ggml_context   * ctx_copy;
    ggml_backend_buffer_t   buf_static;
    ggml_backend_buffer_t   buf_static_cpu;
    std::mt19937            rng;

    struct ggml_tensor * inputs;
    struct ggml_tensor * outputs;
    struct ggml_tensor * labels;

    struct ggml_tensor * loss;
    struct ggml_tensor * pred;
    struct ggml_tensor * ncorrect;

    struct ggml_cgraph * gf;      // forward pass
    struct ggml_cgraph * gb_grad; // forward pass + backward pass (gradients)
    struct ggml_cgraph * gb_opt;  // forward pass + backward pass + optimizer step

    int64_t iter;
    int32_t opt_period;
    int32_t opt_i;
    bool    loss_per_datapoint;

    ggml_opt_get_optimizer_params get_opt_pars;
    void * get_opt_pars_ud;
    struct ggml_tensor * adamw_params;
};

struct ggml_opt_result {
    int64_t              ndata    = 0;
    std::vector<float>   loss;
    std::vector<int32_t> pred;
    int64_t              ncorrect = 0;

    bool    loss_per_datapoint = false;
    int64_t opt_period         = -1;
};

// ====== Dataset ======

ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
    GGML_ASSERT(ne_datapoint >  0);
    GGML_ASSERT(ne_label     >= 0);
    GGML_ASSERT(ndata        >  0);
    GGML_ASSERT(ndata_shard  >  0);

    ggml_opt_dataset_t result = new ggml_opt_dataset;
    result->ndata       = ndata;
    result->ndata_shard = ndata_shard;

    {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 2*ggml_tensor_overhead(),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        result->ctx = ggml_init(params);
    }

    result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
    result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;

    if (ne_label > 0) {
        result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
        result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
    } else {
        result->labels = nullptr;
        result->nbs_labels = 0;
    }

    result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type());

    const int64_t nshards = ndata/ndata_shard;
    result->permutation.resize(nshards);
    for (int64_t i = 0; i < nshards; ++i) {
        result->permutation[i] = i;
    }
    return result;
}

void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
    ggml_backend_buffer_free(dataset->buf);
    ggml_free(dataset->ctx);
    delete dataset;
}

struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
    return dataset->data;
}

struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) {
    return dataset->labels;
}

void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) {
    GGML_ASSERT(idata <= dataset->ndata);

    if (idata < 0) {
        std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng);
        return;
    }

    GGML_ASSERT(idata % dataset->ndata_shard == 0);
    const int64_t ishard_max = idata / dataset->ndata_shard;
    std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng);
}

void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) {
    GGML_ASSERT(  data_batch  && ggml_is_contiguous(data_batch));
    GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));

    const size_t nb_data_batch = ggml_nbytes(data_batch);
    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;

    if (labels_batch) {
        const size_t nb_labels_batch = ggml_nbytes(labels_batch);
        GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels);
    }

    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));

    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];

        const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data;
        ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data);

        if (!labels_batch) {
            continue;
        }

        const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels;
        ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels);
    }
}

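// Usage sketch (not part of the original file): how the dataset API above fits together. The tensor
// sizes below are illustrative assumptions, not values taken from this file.
//
//     ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
//         /*ne_datapoint =*/ 784, /*ne_label =*/ 10, /*ndata =*/ 60000, /*ndata_shard =*/ 1);
//     float * data   = ggml_get_data_f32(ggml_opt_dataset_data(dataset));   // 784*60000 floats, filled by the caller
//     float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset)); //  10*60000 floats, filled by the caller
//     // ... fill data/labels ...
//     // Batches are then copied out shard-by-shard according to the (shuffled) permutation:
//     //     ggml_opt_dataset_shuffle(opt_ctx, dataset, -1);
//     //     ggml_opt_dataset_get_batch(dataset, inputs_batch, labels_batch, /*ibatch =*/ 0);
//     ggml_opt_dataset_free(dataset);
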
// ====== Model / Context ======

struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
    GGML_UNUSED(userdata);

    ggml_opt_optimizer_params result;

    result.adamw.alpha = 0.001f;
    result.adamw.beta1 = 0.9f;
    result.adamw.beta2 = 0.999f;
    result.adamw.eps   = 1e-8f;
    result.adamw.wd    = 0.0f;

    return result;
}

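// Example (not part of the original file): a callback with the same signature can implement a learning
// rate schedule. ggml_opt_fit() below passes a pointer to the current epoch (starting at 1) as userdata,
// so a simple per-epoch decay could look like this; the decay factor is an arbitrary illustration.
//
//     static ggml_opt_optimizer_params lr_schedule(void * userdata) {
//         ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
//         const int64_t epoch = *(const int64_t *) userdata;
//         result.adamw.alpha = 0.001f * powf(0.9f, epoch - 1);
//         return result;
//     }
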
struct ggml_opt_params ggml_opt_default_params(
        ggml_backend_sched_t      backend_sched,
        struct ggml_context     * ctx_compute,
        struct ggml_tensor      * inputs,
        struct ggml_tensor      * outputs,
        enum ggml_opt_loss_type   loss_type) {
    return {
        /*backend_sched   =*/ backend_sched,
        /*ctx_compute     =*/ ctx_compute,
        /*inputs          =*/ inputs,
        /*logits          =*/ outputs,
        /*loss_type       =*/ loss_type,
        /*build_type      =*/ GGML_OPT_BUILD_TYPE_OPT,
        /*opt_period      =*/ 1,
        /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
        /*get_opt_pars_ud =*/ nullptr,
    };
}

static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_map, ggml_context * ctx, ggml_tensor * tensor) {
    if (!tensor) {
        return nullptr;
    }

    if (tensor_map.find(tensor) != tensor_map.end()) {
        return tensor_map[tensor];
    }

    ggml_tensor * new_tensor = ggml_dup_tensor(ctx, tensor);
    tensor_map[tensor] = new_tensor;

    new_tensor->op = tensor->op;
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        new_tensor->nb[i] = tensor->nb[i];
    }
    new_tensor->flags = tensor->flags;
    memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params));
    strcpy(new_tensor->name, tensor->name);
    new_tensor->data      = tensor->data;
    new_tensor->buffer    = tensor->buffer;
    new_tensor->extra     = tensor->extra;
    new_tensor->view_offs = tensor->view_offs;
    new_tensor->view_src  = map_tensor(tensor_map, ctx, tensor->view_src);
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]);
    }

    return new_tensor;
}

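// Added comment: dup_graph makes a structural copy of a graph. Every tensor struct is duplicated into
// ctx via map_tensor (sharing the original data/buffer pointers), so the backend scheduler can allocate
// and run the copy without modifying the tensors of the original graph; the gradient bookkeeping is
// carried over from the source graph to the copy.
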
static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) {
    std::map<ggml_tensor *, ggml_tensor *> tensor_map;

    ggml_cgraph * new_graph = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);

    for (int i = 0; i < graph->n_leafs; i++) {
        ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i]));
    }
    for (int i = 0; i < graph->n_nodes; i++) {
        ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i]));
    }
    for (int i = 0; i < graph->n_nodes; ++i) {
        const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set,         graph->nodes[i]);
        const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]);

        // copy the gradient bookkeeping from the source graph into the copy
        new_graph->grads[igrad_dst]     = graph->grads[igrad_src];
        new_graph->grad_accs[igrad_dst] = graph->grad_accs[igrad_src];
    }

    return new_graph;
}

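// Added comment: (re)allocates a graph for computation. The scheduler allocation of the previously
// used graph is dropped, the requested graph is duplicated into a fresh ctx_copy, and the copy is
// handed to the backend scheduler for allocation. Does nothing if the graph is already allocated.
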
static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
    GGML_ASSERT(graph);
    if (opt_ctx->allocated_graph == graph) {
        return;
    }

    ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph

    {
        ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        ggml_free(opt_ctx->ctx_copy);
        opt_ctx->ctx_copy = ggml_init(params);
    }

    opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);

    ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
    opt_ctx->allocated_graph = graph;
}

ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
    ggml_opt_context_t result = new struct ggml_opt_context;
    result->backend_sched        = params.backend_sched;
    result->allocated_graph      = nullptr;
    result->allocated_graph_copy = nullptr;
    result->ctx_compute          = params.ctx_compute;
    result->ctx_copy             = nullptr;
    result->inputs               = params.inputs;
    result->outputs              = params.outputs;
    result->iter                 = 1;
    result->opt_period           = params.opt_period;
    result->opt_i                = 0;
    result->get_opt_pars         = params.get_opt_pars;
    result->get_opt_pars_ud      = params.get_opt_pars_ud;

    GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
    GGML_ASSERT(result->opt_period >= 1);

    const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
        (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);

    ggml_set_input(result->inputs);
    ggml_set_output(result->outputs);

    result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
    ggml_build_forward_expand(result->gf, result->outputs);

    int n_param = 0;
    for (int i = 0; i < result->gf->n_nodes; ++i) {
        if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
            n_param++;
        }
    }

    {
        // The static context is used for:
        //   - gradients (1 tensor per param if using gradient accumulation)
        //   - optimizer momenta (2 tensors per param)
        //   - labels
        //   - loss + its gradient (up to 5 tensors)
        //   - pred
        //   - ncorrect (2 tensors).
        const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
        const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
        struct ggml_init_params params = {
            /*.mem_size   =*/ size_meta,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        result->ctx_static = ggml_init(params);
    }
    {
        // The static cpu context is used for:
        //   - optimizer parameters (1 for the entire context)
        const size_t size_meta = 1 * ggml_tensor_overhead();
        struct ggml_init_params params = {
            /*.mem_size   =*/ size_meta,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        result->ctx_static_cpu = ggml_init(params);
    }

    switch (params.loss_type) {
        case GGML_OPT_LOSS_TYPE_MEAN: {
            result->labels = nullptr;
            result->loss = ggml_sum(result->ctx_static, result->outputs);
            ggml_set_name(result->loss, "loss_sum");
            const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
            result->loss = ggml_scale(result->ctx_static, result->loss, scale);
            ggml_set_name(result->loss, "loss_mean");
            result->loss_per_datapoint = true;
            break;
        }
        case GGML_OPT_LOSS_TYPE_SUM: {
            result->labels = nullptr;
            result->loss = ggml_sum(result->ctx_static, result->outputs);
            ggml_set_name(result->loss, "loss_sum");
            result->loss_per_datapoint = false;
            break;
        }
        case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
            result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
            ggml_set_input(result->labels);
            ggml_set_name(result->labels, "labels");
            result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
            ggml_set_name(result->loss, "loss_cross_entropy");
            if (result->opt_period > 1) {
                result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
                ggml_set_name(result->loss, "loss_cross_entropy_scaled");
            }
            result->loss_per_datapoint = true;
            break;
        }
        case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
            result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
            ggml_set_input(result->labels);
            ggml_set_name(result->labels, "labels");
            result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
            ggml_set_name(result->loss, "loss_error");
            result->loss = ggml_sqr(result->ctx_static, result->loss);
            ggml_set_name(result->loss, "loss_squared_error");
            result->loss = ggml_sum(result->ctx_static, result->loss);
            ggml_set_name(result->loss, "loss_sum_squared_error");
            const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
            result->loss = ggml_scale(result->ctx_static, result->loss, scale);
            ggml_set_name(result->loss, "loss_mean_squared_error");
            result->loss_per_datapoint = true;
            break;
        }
    }
    ggml_set_output(result->loss);
    ggml_set_loss(result->loss);
    ggml_build_forward_expand(result->gf, result->loss);

    result->pred = ggml_argmax(result->ctx_static, result->outputs);
    ggml_set_name(result->pred, "pred");
    ggml_set_output(result->pred);
    ggml_build_forward_expand(result->gf, result->pred);

    if (result->labels) {
        result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
        ggml_set_name(result->ncorrect, "ncorrect");
        ggml_set_output(result->ncorrect);
        ggml_build_forward_expand(result->gf, result->ncorrect);
    } else {
        result->ncorrect = nullptr;
    }

    if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
        result->gb_grad = nullptr;
        result->gb_opt  = nullptr;

        result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
        result->buf_static_cpu = nullptr;

        ggml_opt_alloc_graph(result, result->gf);

        return result;
    }

    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
    result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
    ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);

    if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
        result->gb_opt = nullptr;

        result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
        result->buf_static_cpu = nullptr;

        ggml_opt_alloc_graph(result, result->gb_grad);
        ggml_graph_reset(result->gb_grad);

        return result;
    }

    GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);

    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
    result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);

    result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
    ggml_set_input(result->adamw_params);
    ggml_set_name(result->adamw_params, "adamw_params");

    for (int i = result->gf->n_nodes-1; i >= 0; --i) {
        struct ggml_tensor * node = result->gb_opt->nodes[i];
        struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);

        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
            struct ggml_tensor * m        = ggml_dup_tensor(result->ctx_static, node);
            struct ggml_tensor * v        = ggml_dup_tensor(result->ctx_static, node);
            struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
            ggml_build_forward_expand(result->gb_opt, opt_step);
        }
    }

    result->buf_static = ggml_backend_alloc_ctx_tensors(
        result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));

    result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());

    ggml_opt_alloc_graph(result, result->gb_opt);
    ggml_graph_reset(result->gb_opt);

    return result;
}

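// Usage sketch (not part of the original file): driving an optimization context by hand instead of
// going through ggml_opt_fit(). backend_sched, ctx_compute, inputs, outputs, dataset and nbatches are
// placeholders for whatever the caller has built.
//
//     ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs,
//                                                          GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
//     ggml_opt_context_t opt_ctx = ggml_opt_init(opt_params);
//     ggml_opt_result_t  result  = ggml_opt_result_init();
//     for (int64_t ibatch = 0; ibatch < nbatches; ++ibatch) {
//         ggml_opt_dataset_get_batch(dataset, ggml_opt_inputs(opt_ctx), ggml_opt_labels(opt_ctx), ibatch);
//         ggml_opt_forward_backward(opt_ctx, result); // forward + backward (+ AdamW step every opt_period batches)
//     }
//     ggml_opt_result_free(result);
//     ggml_opt_free(opt_ctx);
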
void ggml_opt_free(ggml_opt_context_t opt_ctx) {
    if (opt_ctx == nullptr) {
        return;
    }
    ggml_backend_buffer_free(opt_ctx->buf_static);
    ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
    ggml_free(opt_ctx->ctx_static);
    ggml_free(opt_ctx->ctx_static_cpu);
    delete opt_ctx;
}

void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
    if (optimizer) {
        ggml_graph_reset(opt_ctx->gb_opt);
        opt_ctx->iter = 1;
    } else {
        ggml_graph_reset(opt_ctx->gb_grad);
    }
}

struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
    return opt_ctx->inputs;
}

struct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) {
    return opt_ctx->outputs;
}

struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) {
    return opt_ctx->labels;
}

struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) {
    return opt_ctx->loss;
}

struct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) {
    return opt_ctx->pred;
}

struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) {
    return opt_ctx->ncorrect;
}

struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) {
    return ggml_graph_get_grad_acc(opt_ctx->gb_opt, node);
}

// ====== Optimization Result ======

ggml_opt_result_t ggml_opt_result_init() {
    return new ggml_opt_result;
}

void ggml_opt_result_free(ggml_opt_result_t result) {
    delete result;
}

void ggml_opt_result_reset(ggml_opt_result_t result) {
    result->ndata = 0;
    result->loss.clear();
    result->pred.clear();
    result->ncorrect = 0;
}

void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) {
    *ndata = result->ndata;
}

void ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) {
    const int64_t nbatches = result->loss.size(); // Number of physical batches.

    if (nbatches == 0) {
        *loss = 0.0;
        *unc  = NAN;
        return;
    }

    double sum         = 0.0;
    double sum_squared = 0.0;

    for (const float & loss : result->loss) {
        // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch.
        const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss;
        sum         += loss_scaled;
        sum_squared += loss_scaled*loss_scaled;
    }

    const double mean = sum/nbatches;
    *loss = result->loss_per_datapoint ? mean : sum;

    if (!unc) {
        return;
    }

    if (nbatches < 2) {
        *unc = NAN;
        return;
    }

    const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1)
    *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1));
}

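// Added note on the uncertainty above: var_sum = E[l^2] - E[l]^2 is the biased variance of the
// per-batch losses l_1..l_n. For a per-datapoint loss the reported uncertainty is the standard error
// of the mean, sqrt(var_sum/(n-1)); for a summed loss it is the Bessel-corrected standard deviation
// of the batch losses, sqrt(var_sum * n/(n-1)).
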
void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) {
    for (size_t i = 0; i < result->pred.size(); ++i) {
        pred[i] = result->pred[i];
    }
}

void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) {
    *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN;

    if (!unc) {
        return;
    }

    *unc = result->ncorrect >= 0 && result->ndata >= 2 ?
        sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN;
}

// ====== Computation ======

static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
    if (graph != opt_ctx->gf) {
        struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);

        GGML_ASSERT(opt_pars.adamw.alpha >  0.0f);
        GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
        GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
        GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
        GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
        GGML_ASSERT(opt_pars.adamw.eps   >= 0.0f);
        GGML_ASSERT(opt_pars.adamw.wd    >= 0.0f);
        GGML_ASSERT(opt_pars.adamw.wd    <= 1.0f);

        // beta1h, beta2h == 1/(1 - beta^iter), the AdamW bias-correction factors for the current iteration
        const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
        const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));

        // packed into the adamw_params tensor passed to ggml_opt_step_adamw:
        // [alpha, beta1, beta2, eps, wd, beta1h, beta2h]
        float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
        adamw_par_data[0] = opt_pars.adamw.alpha;
        adamw_par_data[1] = opt_pars.adamw.beta1;
        adamw_par_data[2] = opt_pars.adamw.beta2;
        adamw_par_data[3] = opt_pars.adamw.eps;
        adamw_par_data[4] = opt_pars.adamw.wd;
        adamw_par_data[5] = beta1h;
        adamw_par_data[6] = beta2h;
    }

    ggml_opt_alloc_graph(opt_ctx, graph);
    ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
    opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;

    if (!result) {
        return;
    }

    if (result->ndata == 0) {
        result->loss_per_datapoint = opt_ctx->loss_per_datapoint;
        result->opt_period         = opt_ctx->opt_period;
    } else {
        GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint);
        GGML_ASSERT(result->opt_period         == opt_ctx->opt_period);
    }

    const int64_t ndata = opt_ctx->outputs->ne[1];
    GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported");
    result->ndata += ndata;

    GGML_ASSERT(ggml_is_scalar(opt_ctx->loss));
    GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32);
    float loss;
    ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
    result->loss.push_back(loss);

    GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
    std::vector<int32_t> pred(ndata);
    ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
    result->pred.insert(result->pred.end(), pred.begin(), pred.end());

    if (!opt_ctx->labels || result->ncorrect < 0) {
        result->ncorrect = -1;
        return;
    }

    GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect));
    GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64);
    int64_t ncorrect;
    ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect));
    result->ncorrect += ncorrect;
}

void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
    ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
}

void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
    if (opt_ctx->opt_period == 1) {
        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
        return;
    }

    const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
    if (opt_i_next == 0) {
        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
        ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
    } else {
        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
    }
    opt_ctx->opt_i = opt_i_next;
}

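// Added note on gradient accumulation: with opt_period == N > 1, ggml_opt_forward_backward() above runs
// N-1 gradient-only evaluations (gb_grad) followed by one evaluation that includes the optimizer step
// (gb_opt), after which the accumulated gradients are cleared. One logical batch is therefore processed
// as N consecutive physical batches, for example:
//
//     for (int i = 0; i < N; ++i) {
//         ggml_opt_dataset_get_batch(dataset, ggml_opt_inputs(opt_ctx), ggml_opt_labels(opt_ctx), ibatch*N + i);
//         ggml_opt_forward_backward(opt_ctx, result); // parameters are updated only on the N-th call
//     }
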
// ====== High-Level Functions ======

void ggml_opt_epoch(
        ggml_opt_context_t      opt_ctx,
        ggml_opt_dataset_t      dataset,
        ggml_opt_result_t       result_train,
        ggml_opt_result_t       result_eval,
        int64_t                 idata_split,
        ggml_opt_epoch_callback callback_train,
        ggml_opt_epoch_callback callback_eval) {
    struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
    struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
    struct ggml_tensor * data   = ggml_opt_dataset_data(dataset);
    GGML_ASSERT(data->ne[0] == inputs->ne[0]);

    const int64_t ndata       = data->ne[1];
    const int64_t ndata_batch = inputs->ne[1];

    GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0);
    const int64_t nbatches = ndata/ndata_batch;

    idata_split = idata_split < 0 ? ndata : idata_split;
    GGML_ASSERT(idata_split % ndata_batch == 0);
    const int64_t ibatch_split = idata_split / ndata_batch;

    int64_t ibatch = 0;
    int64_t t_loop_start = ggml_time_us();
    for (; ibatch < ibatch_split; ++ibatch) {
        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
        ggml_opt_forward_backward(opt_ctx, result_train);
        if (callback_train) {
            callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
        }
    }
    t_loop_start = ggml_time_us();
    for (; ibatch < nbatches; ++ibatch) {
        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
        ggml_opt_forward(opt_ctx, result_eval);
        if (callback_eval) {
            callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
        }
    }
}

void ggml_opt_epoch_callback_progress_bar(
        bool               train,
        ggml_opt_context_t opt_ctx,
        ggml_opt_dataset_t dataset,
        ggml_opt_result_t  result,
        int64_t            ibatch,
        int64_t            ibatch_max,
        int64_t            t_start_us) {
    fprintf(stderr, "%s[", train ? "train: " : "val: ");

    constexpr int64_t bar_length = 25;
    for (int64_t j = 0; j < bar_length; ++j) {
        const int64_t ibatch_j = ibatch_max * j/bar_length;
        if (ibatch_j < ibatch) {
            fprintf(stderr, "=");
        } else if (ibatch_max * (j - 1)/bar_length < ibatch) {
            fprintf(stderr, ">");
        } else {
            fprintf(stderr, " ");
        }
    }

    const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1];
    const int64_t idata      = ibatch*batch_size;
    const int64_t idata_max  = ibatch_max*batch_size;

    double loss;
    double loss_unc;
    ggml_opt_result_loss(result, &loss, &loss_unc);

    double accuracy;
    double accuracy_unc;
    ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc);

    const int64_t t_ibatch_us = ggml_time_us() - t_start_us;
    int64_t t_ibatch_s = t_ibatch_us / 1000000;
    const int64_t t_ibatch_h = t_ibatch_s / 3600;
    t_ibatch_s -= t_ibatch_h * 3600;
    const int64_t t_ibatch_m = t_ibatch_s / 60;
    t_ibatch_s -= t_ibatch_m * 60;

    const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch;
    int64_t t_eta_s = t_eta_us / 1000000;
    const int64_t t_eta_h = t_eta_s / 3600;
    t_eta_s -= t_eta_h * 3600;
    const int64_t t_eta_m = t_eta_s / 60;
    t_eta_s -= t_eta_m * 60;

    fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
            "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
            idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
            t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);

    if (ibatch == ibatch_max) {
        fprintf(stderr, "\n");
    }
    fflush(stderr);

    GGML_UNUSED(dataset);
}

void ggml_opt_fit(
        ggml_backend_sched_t            backend_sched,
        ggml_context                  * ctx_compute,
        ggml_tensor                   * inputs,
        ggml_tensor                   * outputs,
        ggml_opt_dataset_t              dataset,
        enum ggml_opt_loss_type         loss_type,
        ggml_opt_get_optimizer_params   get_opt_pars,
        int64_t                         nepoch,
        int64_t                         nbatch_logical,
        float                           val_split,
        bool                            silent) {
    ggml_time_init();
    const int64_t t_start_us = ggml_time_us();

    const int64_t ndata           = ggml_opt_dataset_data(dataset)->ne[1];
    const int64_t nbatch_physical = inputs->ne[1];
    GGML_ASSERT(ndata          % nbatch_logical  == 0);
    GGML_ASSERT(nbatch_logical % nbatch_physical == 0);

    const int64_t opt_period       = nbatch_logical / nbatch_physical;
    const int64_t nbatches_logical = ndata / nbatch_logical;

    GGML_ASSERT(val_split >= 0.0f);
    GGML_ASSERT(val_split <  1.0f);
    const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical)
    const int64_t idata_split  = ibatch_split * nbatch_physical;

    int64_t epoch = 1;

    ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
    params.opt_period      = opt_period;
    params.get_opt_pars    = get_opt_pars;
    params.get_opt_pars_ud = &epoch;
    ggml_opt_context_t opt_ctx = ggml_opt_init(params);

    // Shuffling the data is generally useful but only makes a difference if not all data is used in a single batch.
    if (nbatch_logical < ndata) {
        ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation).
    }

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_val   = ggml_opt_result_init();

    ggml_opt_epoch_callback epoch_callback = silent ? nullptr : ggml_opt_epoch_callback_progress_bar;

    for (; epoch <= nepoch; ++epoch) {
        if (nbatch_logical < idata_split) {
            ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split);
        }

        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_val);

        if (!silent) {
            fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch);
        }
        ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback);
        if (!silent) {
            fprintf(stderr, "\n");
        }
    }

    if (!silent) {
        int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000;
        const int64_t t_total_h = t_total_s / 3600;
        t_total_s -= t_total_h * 3600;
        const int64_t t_total_m = t_total_s / 60;
        t_total_s -= t_total_m * 60;
        fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s);
    }

    ggml_opt_free(opt_ctx);
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_val);
}

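// Usage sketch (not part of the original file): the one-call training entry point. backend_sched,
// ctx_compute, inputs and outputs are placeholders for whatever model graph the caller has built;
// the dataset is assumed to have been filled as in the sketch after ggml_opt_dataset_get_batch().
//
//     ggml_opt_fit(backend_sched, ctx_compute, inputs, outputs, dataset,
//                  GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, ggml_opt_get_default_optimizer_params,
//                  /*nepoch =*/ 10, /*nbatch_logical =*/ 512, /*val_split =*/ 0.1f, /*silent =*/ false);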