finetune.cpp
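// Minimal full-model fine-tuning example: loads a GGUF model, builds a training
// dataset from the tokenized prompt/file text, runs two optimization epochs via
// the llama_opt/ggml_opt API, and saves the result as finetuned-model.gguf.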

#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
int main(int argc, char ** argv) {
    common_params params;

    params.escape = false;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
        return 1;
    }

    if (params.use_mmap) {
        LOG_INF("%s: force disabling memory mapping because it would result in read-only pointers to the weights\n", __func__);
        params.use_mmap = false;
    }
    if (params.cache_type_k != GGML_TYPE_F32) {
        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_k = GGML_TYPE_F32;
    }
    if (params.cache_type_v != GGML_TYPE_F32) {
        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_v = GGML_TYPE_F32;
    }
    common_init();
    llama_backend_init();
    llama_numa_init(params.numa);

    // load the model and apply lora adapter, if any
    common_init_result llama_init = common_init_from_params(params);

    llama_model_ptr   & model = llama_init.model;
    llama_context_ptr & ctx   = llama_init.context;

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    // print system information
    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }
    // fraction of the data reserved for validation
    constexpr float val_split = 0.05f;

    // tokenize the training text and build the optimization dataset from it
    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);

    // use a constant AdamW learning rate for all trainable parameters
    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
    optimizer_params.adamw.alpha = 1e-7f; // learning rate

    struct llama_opt_params lopt_params {
        /*n_ctx_train     =*/ 0,
        /*param_filter    =*/ llama_opt_param_filter_all,
        /*param_filter_ud =*/ nullptr,
        /*get_opt_pars    =*/ ggml_opt_get_constant_optimizer_params,
        /*get_opt_pars_ud =*/ &optimizer_params,
    };
    llama_opt_init(ctx.get(), model.get(), lopt_params);
    // use the first (1 - val_split) of the data for training, the remainder for evaluation
    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

    for (int epoch = 0; epoch < 2; ++epoch) {
        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");

        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_eval);
    }
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);

    // write the updated weights back to disk
    llama_model_save_to_file(model.get(), "finetuned-model.gguf");

    llama_backend_free();

    return 0;
}
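
// A rough usage sketch (assumptions: this example builds as the llama-finetune
// binary through the project's usual CMake targets, the file names below are
// placeholders, and -m, -f, -c, -ngl are the standard flags of the common
// argument parser):
//
//   ./build/bin/llama-finetune -m base-model.gguf -f train-data.txt -c 512 -ngl 99
//
// The text passed via -f is tokenized into the training dataset above, the model
// is trained for two epochs, and the updated weights are saved as
// finetuned-model.gguf in the current working directory.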