train.h

// Various helper functions and utilities for training

#pragma once

#include <string>
#include <random>
#include <vector>

#include "ggml.h"
#include "llama.h"

typedef std::string mt19937_state;

struct train_state {
    struct ggml_opt_context * opt;

    uint64_t train_its;
    uint64_t train_samples;
    uint64_t train_tokens;
    uint64_t train_epochs;

    size_t        shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes)
    mt19937_state shuffle_rng_state_current;
    mt19937_state shuffle_rng_state_next;
    size_t        shuffle_sample_count;
    size_t        shuffle_next_sample;
};

struct train_params_common {
    const char * fn_train_data;
    const char * fn_checkpoint_in;
    const char * fn_checkpoint_out;
    const char * pattern_fn_it;
    const char * fn_latest;

    bool print_usage;

    int save_every;

    uint32_t seed;

    int n_ctx;
    int n_threads;
    int n_batch;
    int n_gradient_accumulation;
    int n_epochs;
    int n_gpu_layers;

    bool custom_n_ctx;

    bool use_flash;
    bool use_checkpointing;

    std::string sample_start;
    bool include_sample_start;
    bool escape;
    bool overlapping_samples;
    bool fill_with_next_samples;
    bool separate_with_eos;
    bool separate_with_bos;
    bool sample_random_offsets;

    bool force_reshuffle;

    int   warmup;
    int   cos_decay_steps;
    float cos_decay_restart;
    float cos_decay_min;
    bool  enable_restart;

    int   opt_past;
    float opt_delta;
    int   opt_max_no_improvement;

    int   adam_n_iter;
    float adam_alpha;
    float adam_min_alpha;
    float adam_decay;
    int   adam_decay_min_ndim;
    float adam_beta1;
    float adam_beta2;
    float adam_gclip;
    float adam_eps_f;
};

typedef void (*save_train_files_callback)(void * data, struct train_state * train);

struct train_opt_callback_data {
    struct train_params_common * params;
    struct train_state         * train;
    save_train_files_callback    save_cb;
    void                       * save_data;
    struct llama_context       * lctx;
    int                          last_save_iter;
    llama_token                * tokens_data;
    size_t                       tokens_size;
    size_t                     * samples_begin;
    size_t                     * samples_size;
    size_t                     * shuffled_samples_offs;
    size_t                     * shuffled_samples_begin;
    size_t                     * shuffled_samples_size;
    size_t                       samples_count;
    struct ggml_tensor         * tokens_input;
    struct ggml_tensor         * target_probs;
    int                          first_iter;
    int                          first_epoch;
    int                          iter_at_last_epoch;
    int64_t                      last_time;
    double                       millis_per_iter;
};

struct train_state * init_train_state();
void free_train_state(struct train_state * state);

struct train_params_common get_default_train_params_common();
void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params);

bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param);
void finish_processing_train_args(struct train_params_common * params);

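// Example (illustrative sketch, not part of this header): a typical training driver
// gets the defaults, lets consume_common_train_arg handle the shared command-line
// options (a true return is taken here to mean the argument was recognized), and owns
// one heap-allocated train_state. Error handling is reduced to a bare minimum.
//
//   struct train_params_common params = get_default_train_params_common();
//   for (int i = 1; i < argc; ++i) {
//       bool invalid_param = false;
//       if (!consume_common_train_arg(argc, argv, &i, &params, &invalid_param)) {
//           // not a common option: handle program-specific arguments here
//       }
//       if (invalid_param || params.print_usage) {
//           print_common_train_usage(argc, argv, &params);
//           exit(invalid_param ? 1 : 0);
//       }
//   }
//   finish_processing_train_args(&params);
//
//   struct train_state * train = init_train_state();
//   // ... training loop ...
//   free_train_state(train);
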
struct random_normal_distribution;
struct random_uniform_distribution;

struct random_normal_distribution  * init_random_normal_distribution (int seed, float mean, float std, float min, float max);
struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max);

void free_random_normal_distribution (struct random_normal_distribution  * rnd);
void free_random_uniform_distribution(struct random_uniform_distribution * rnd);

struct ggml_tensor * randomize_tensor_normal (struct ggml_tensor * tensor, struct random_normal_distribution * rnd);
struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd);

// generate random float in interval [0,1)
float frand();
float frand_normal (struct random_normal_distribution * rnd);
float frand_uniform(struct random_uniform_distribution * rnd);

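// Example (illustrative sketch): initializing a freshly created F32 tensor with the
// normal-distribution helper above (values clamped to [min, max]). The ggml context
// size, tensor shape, and distribution parameters are arbitrary placeholders.
//
//   struct ggml_init_params iparams = { 16*1024*1024 /*mem_size*/, NULL /*mem_buffer*/, false /*no_alloc*/ };
//   struct ggml_context * ctx = ggml_init(iparams);
//   struct ggml_tensor  * w   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
//
//   struct random_normal_distribution * rnd = init_random_normal_distribution(
//       1337 /*seed*/, 0.0f /*mean*/, 0.02f /*std*/, -1.0f /*min*/, 1.0f /*max*/);
//   randomize_tensor_normal(w, rnd);
//   free_random_normal_distribution(rnd);
//
//   ggml_free(ctx);
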
int   clamp (const int v, const int min, const int max);
float fclamp(const float v, const float min, const float max);

void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0);
void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1);
void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2);
void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3);

size_t tokenize_file(
        struct llama_context     * lctx,
        const char               * filename,
        const std::string        & sample_start,
        bool                       include_sample_start,
        bool                       overlapping_samples,
        unsigned                   context_length,
        std::vector<llama_token> & out_tokens,
        std::vector<size_t>      & out_samples_begin,
        std::vector<size_t>      & out_samples_size);

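// Example (illustrative sketch): tokenizing a training file into one token stream plus
// per-sample begin/size offsets. "train.txt" and the "<s>" sample delimiter are
// placeholders; lctx is assumed to be a llama_context created by the caller.
//
//   std::vector<llama_token> tokens;
//   std::vector<size_t>      samples_begin;
//   std::vector<size_t>      samples_size;
//   tokenize_file(lctx, "train.txt", "<s>" /*sample_start*/,
//                 false /*include_sample_start*/, false /*overlapping_samples*/,
//                 512   /*context_length*/,
//                 tokens, samples_begin, samples_size);
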
int64_t get_example_targets_batch(
        struct llama_context * lctx,
        struct ggml_tensor   * tokens_input,
        struct ggml_tensor   * target_probs,
        int64_t                example_id,
        const size_t         * samples_offs,
        const size_t         * samples_begin,
        const size_t         * samples_size,
        size_t                 samples_count,
        const llama_token    * train_data,
        size_t                 n_train_data,
        bool                   separate_with_eos,
        bool                   separate_with_bos,
        bool                   fill_with_next_samples,
        bool                   sample_random_offsets);

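// Note (assumed from the llama.cpp training examples): get_example_targets_batch fills
// tokens_input and target_probs with the batch of examples starting at example_id,
// reading the token stream through the (shuffled) sample offsets/begins/sizes. There it
// is called with tokens_input as a 2-D I32 tensor of shape [n_tokens, n_batch] and
// target_probs as a 3-D F32 tensor of shape [n_vocab, n_tokens, n_batch]; see train.cpp
// for the authoritative behavior.
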
void          mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state);
mt19937_state mt19937_get_state(const std::mt19937& rng);
mt19937_state mt19937_seed_to_state(unsigned seed);

mt19937_state shuffle_samples(
        const mt19937_state & rng_state,
        size_t              * shuffled_offs,
        size_t              * shuffled_begins,
        size_t              * shuffled_sizes,
        const size_t        * begins,
        const size_t        * sizes,
        size_t                count);

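// Example (illustrative sketch): producing a shuffled sample order from the offsets
// returned by tokenize_file above. shuffle_samples returns the advanced RNG state as a
// string, which is what train_state keeps in shuffle_rng_state_current/_next so the
// order can be reproduced from a checkpoint.
//
//   std::vector<size_t> shuffled_offs  (samples_begin.size());
//   std::vector<size_t> shuffled_begins(samples_begin.size());
//   std::vector<size_t> shuffled_sizes (samples_begin.size());
//
//   mt19937_state rng_state  = mt19937_seed_to_state(1337 /*seed*/);
//   mt19937_state next_state = shuffle_samples(
//       rng_state,
//       shuffled_offs.data(), shuffled_begins.data(), shuffled_sizes.data(),
//       samples_begin.data(), samples_size.data(), samples_begin.size());
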
size_t hash_combine(size_t h1, size_t h2);

size_t compute_samples_hash(
        const char   * fn,
        const size_t * samples_begin,
        const size_t * samples_size,
        size_t         sample_count);

std::string replace_str(const char * s, const char * needle, const char * replacement);

void print_duration(double milliseconds);

float cosine_decay(
        int64_t step,
        int64_t decay_steps,
        float   minimum);

float cosine_decay_restart(
        int64_t step,
        int64_t decay_steps,
        float   minimum,
        float   restart_step_mult);

float learning_schedule(
        int64_t step,
        int64_t warmup_steps,
        int64_t decay_steps,
        float   learning_rate,
        float   overall_minimum,
        float   cos_decay_minimum,
        float   cos_decay_restart_step_mult,
        bool    enable_restart);

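// Example (illustrative sketch): learning_schedule combines a linear warmup with the
// cosine decay above (optionally with warm restarts) and returns a per-step multiplier
// for the base learning rate; the exact formula lives in train.cpp. The values below
// are arbitrary.
//
//   float sched = learning_schedule(
//       /*step*/ 250, /*warmup_steps*/ 100, /*decay_steps*/ 1000,
//       /*learning_rate*/ 1e-3f, /*overall_minimum*/ 1e-5f, /*cos_decay_minimum*/ 0.1f,
//       /*cos_decay_restart_step_mult*/ 2.0f, /*enable_restart*/ false);
//   // sched is then typically applied to adam_alpha by the optimizer callback
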
void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name);

void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt);
void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt);

bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train);
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train);

std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);

void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel);

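// Example (illustrative sketch): writing the training state into a GGUF checkpoint and
// deriving a per-iteration file name. A real checkpoint written by the training
// examples also contains model and optimizer tensors; this only shows the train_state
// part. get_train_filename is assumed to substitute pattern_it in filename with the
// iteration number (or with latest when iteration is negative); the file name pattern
// below is a placeholder.
//
//   struct gguf_context * fctx = gguf_init_empty();
//   save_train_state_gguf(fctx, train);
//
//   // e.g. "checkpoint-ITERATION.gguf" -> "checkpoint-300.gguf"
//   std::string fn = get_train_filename("checkpoint-ITERATION.gguf", "ITERATION", "LATEST", train->opt->iter);
//   gguf_write_to_file(fctx, fn.c_str(), false /*only_meta*/);
//   gguf_free(fctx);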