// train.h
// Various helper functions and utilities for training
#pragma once

#include <cstddef>
#include <cstdint>
#include <random>
#include <string>
#include <vector>

#include "ggml.h"
#include "llama.h"
  8. typedef std::string mt19937_state;
  9. struct train_state {
  10. struct ggml_opt_context * opt;
  11. uint64_t train_its;
  12. uint64_t train_samples;
  13. uint64_t train_tokens;
  14. uint64_t train_epochs;
  15. size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes)
  16. mt19937_state shuffle_rng_state_current;
  17. mt19937_state shuffle_rng_state_next;
  18. size_t shuffle_sample_count;
  19. size_t shuffle_next_sample;
  20. };
  21. struct train_params_common {
  22. const char * fn_train_data;
  23. const char * fn_checkpoint_in;
  24. const char * fn_checkpoint_out;
  25. const char * pattern_fn_it;
  26. const char * fn_latest;
  27. bool print_usage;
  28. int save_every;
  29. uint32_t seed;
  30. int n_ctx;
  31. int n_threads;
  32. int n_batch;
  33. int n_gradient_accumulation;
  34. int n_epochs;
  35. bool custom_n_ctx;
  36. bool use_flash;
  37. bool use_checkpointing;
  38. std::string sample_start;
  39. bool include_sample_start;
  40. bool escape;
  41. bool overlapping_samples;
  42. bool fill_with_next_samples;
  43. bool separate_with_eos;
  44. bool separate_with_bos;
  45. bool sample_random_offsets;
  46. bool force_reshuffle;
  47. int warmup;
  48. int cos_decay_steps;
  49. float cos_decay_restart;
  50. float cos_decay_min;
  51. bool enable_restart;
  52. int opt_past;
  53. float opt_delta;
  54. int opt_max_no_improvement;
  55. int adam_n_iter;
  56. float adam_alpha;
  57. float adam_min_alpha;
  58. float adam_decay;
  59. int adam_decay_min_ndim;
  60. float adam_beta1;
  61. float adam_beta2;
  62. float adam_gclip;
  63. float adam_eps_f;
  64. };
  65. typedef void (*save_train_files_callback)(void * data, struct train_state * train);
  66. struct train_opt_callback_data {
  67. struct train_params_common * params;
  68. struct train_state * train;
  69. save_train_files_callback save_cb;
  70. void * save_data;
  71. struct llama_context * lctx;
  72. int last_save_iter;
  73. llama_token * tokens_data;
  74. size_t tokens_size;
  75. size_t * samples_begin;
  76. size_t * samples_size;
  77. size_t * shuffled_samples_offs;
  78. size_t * shuffled_samples_begin;
  79. size_t * shuffled_samples_size;
  80. size_t samples_count;
  81. struct ggml_tensor * tokens_input;
  82. struct ggml_tensor * target_probs;
  83. int first_iter;
  84. int first_epoch;
  85. int iter_at_last_epoch;
  86. int64_t last_time;
  87. double millis_per_iter;
  88. };
  89. struct train_state * init_train_state();
  90. void free_train_state(struct train_state * state);
  91. struct train_params_common get_default_train_params_common();
  92. void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params);
  93. bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param);
  94. void finish_processing_train_args(struct train_params_common * params);
  95. struct random_normal_distribution;
  96. struct random_uniform_distribution;
  97. struct random_normal_distribution * init_random_normal_distribution (int seed, float mean, float std, float min, float max);
  98. struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max);
  99. void free_random_normal_distribution (struct random_normal_distribution * rnd);
  100. void free_random_uniform_distribution(struct random_uniform_distribution * rnd);
  101. struct ggml_tensor * randomize_tensor_normal (struct ggml_tensor * tensor, struct random_normal_distribution * rnd);
  102. struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd);
  103. // generate random float in interval [0,1)
  104. float frand();
  105. float frand_normal (struct random_normal_distribution * rnd);
  106. float frand_uniform(struct random_uniform_distribution * rnd);
  107. int clamp (const int v, const int min, const int max);
  108. float fclamp(const float v, const float min, const float max);
  109. void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0);
  110. void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1);
  111. void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2);
  112. void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3);
  113. size_t tokenize_file(
  114. struct llama_context * lctx,
  115. const char * filename,
  116. const std::string & sample_start,
  117. bool include_sample_start,
  118. bool overlapping_samples,
  119. unsigned context_length,
  120. std::vector<llama_token> & out_tokens,
  121. std::vector<size_t> & out_samples_begin,
  122. std::vector<size_t> & out_samples_size);
  123. int64_t get_example_targets_batch(
  124. struct llama_context * lctx,
  125. struct ggml_tensor * tokens_input,
  126. struct ggml_tensor * target_probs,
  127. int64_t example_id,
  128. const size_t * samples_offs,
  129. const size_t * samples_begin,
  130. const size_t * samples_size,
  131. size_t samples_count,
  132. const llama_token * train_data,
  133. size_t n_train_data,
  134. bool separate_with_eos,
  135. bool separate_with_bos,
  136. bool fill_with_next_samples,
  137. bool sample_random_offsets);
  138. void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state);
  139. mt19937_state mt19937_get_state(const std::mt19937& rng);
  140. mt19937_state mt19937_seed_to_state(unsigned seed);
  141. mt19937_state shuffle_samples(
  142. const mt19937_state & rng_state,
  143. size_t * shuffled_offs,
  144. size_t * shuffled_begins,
  145. size_t * shuffled_sizes,
  146. const size_t * begins,
  147. const size_t * sizes,
  148. size_t count);
  149. size_t hash_combine(size_t h1, size_t h2);
  150. size_t compute_samples_hash(
  151. const char* fn,
  152. const size_t* samples_begin,
  153. const size_t* samples_size,
  154. size_t sample_count);
  155. std::string replace_str(const char * s, const char * needle, const char * replacement);
  156. void print_duration(double milliseconds);
  157. float cosine_decay(
  158. int64_t step,
  159. int64_t decay_steps,
  160. float minimum);
  161. float cosine_decay_restart(
  162. int64_t step,
  163. int64_t decay_steps,
  164. float minimum,
  165. float restart_step_mult);
  166. float learning_schedule(
  167. int64_t step,
  168. int64_t warmup_steps,
  169. int64_t decay_steps,
  170. float learning_rate,
  171. float overall_minimum,
  172. float cos_decay_minimum,
  173. float cos_decay_restart_step_mult,
  174. bool enable_restart);
  175. void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name);
  176. void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt);
  177. void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt);
  178. bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train);
  179. void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train);
  180. std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);
  181. void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel);