|
|
@@ -8,6 +8,7 @@
|
|
|
#include <cstring>
|
|
|
#include <ctime>
|
|
|
#include <cfloat>
|
|
|
+#include <chrono>
|
|
|
#include <cmath>
|
|
|
#include <numeric>
|
|
|
#include <random>
|
|
|
@@ -162,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|
|
cur_p->size = k;
|
|
|
}
|
|
|
|
|
|
+static uint32_t get_rng_seed(uint32_t seed) {
|
|
|
+ if (seed == LLAMA_DEFAULT_SEED) {
|
|
|
+ // use system clock if std::random_device is not a true RNG
|
|
|
+ static bool is_rd_prng = std::random_device().entropy() == 0;
|
|
|
+ if (is_rd_prng) {
|
|
|
+ return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
|
|
|
+ }
|
|
|
+ std::random_device rd;
|
|
|
+ return rd();
|
|
|
+ }
|
|
|
+ return seed;
|
|
|
+}
|
|
|
+
|
|
|
// llama_sampler API
|
|
|
|
|
|
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
|
|
@@ -387,6 +401,7 @@ struct llama_sampler * llama_sampler_init_greedy() {
|
|
|
|
|
|
struct llama_sampler_dist {
|
|
|
const uint32_t seed;
|
|
|
+ uint32_t seed_cur;
|
|
|
|
|
|
std::mt19937 rng;
|
|
|
};
|
|
|
@@ -416,7 +431,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
|
|
|
|
|
|
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
|
|
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
|
|
- ctx->rng = std::mt19937(ctx->seed);
|
|
|
+ ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
|
+ ctx->rng.seed(ctx->seed_cur);
|
|
|
}
|
|
|
|
|
|
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
|
|
@@ -433,11 +449,13 @@ static struct llama_sampler_i llama_sampler_dist_i = {
|
|
|
};
|
|
|
|
|
|
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
|
|
+ auto seed_cur = get_rng_seed(seed);
|
|
|
return new llama_sampler {
|
|
|
/* .iface = */ &llama_sampler_dist_i,
|
|
|
/* .ctx = */ new llama_sampler_dist {
|
|
|
- /* .seed = */ seed,
|
|
|
- /* .rng = */ std::mt19937(seed),
|
|
|
+ /* .seed = */ seed,
|
|
|
+ /* .seed_cur = */ seed_cur,
|
|
|
+ /* .rng = */ std::mt19937(seed_cur),
|
|
|
},
|
|
|
};
|
|
|
}
|
|
|
@@ -1032,6 +1050,7 @@ struct llama_sampler_mirostat {
|
|
|
const int32_t n_vocab;
|
|
|
|
|
|
const uint32_t seed;
|
|
|
+ uint32_t seed_cur;
|
|
|
|
|
|
const float tau;
|
|
|
const float eta;
|
|
|
@@ -1100,7 +1119,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa
|
|
|
static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
|
|
|
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
|
|
|
ctx->mu = 2.0f*ctx->tau;
|
|
|
- ctx->rng = std::mt19937(ctx->seed);
|
|
|
+ ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
|
+ ctx->rng.seed(ctx->seed_cur);
|
|
|
}
|
|
|
|
|
|
static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
|
|
|
@@ -1117,16 +1137,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
|
|
|
};
|
|
|
|
|
|
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
|
|
+ auto seed_cur = get_rng_seed(seed);
|
|
|
return new llama_sampler {
|
|
|
/* .iface = */ &llama_sampler_mirostat_i,
|
|
|
/* .ctx = */ new llama_sampler_mirostat {
|
|
|
- /* .n_vocab = */ n_vocab,
|
|
|
- /* .seed = */ seed,
|
|
|
- /* .tau = */ tau,
|
|
|
- /* .eta = */ eta,
|
|
|
- /* .m = */ m,
|
|
|
- /* .mu = */ 2.0f*tau,
|
|
|
- /* .rng = */ std::mt19937(seed),
|
|
|
+ /* .n_vocab = */ n_vocab,
|
|
|
+ /* .seed = */ seed,
|
|
|
+ /* .seed_cur = */ seed_cur,
|
|
|
+ /* .tau = */ tau,
|
|
|
+ /* .eta = */ eta,
|
|
|
+ /* .m = */ m,
|
|
|
+ /* .mu = */ 2.0f*tau,
|
|
|
+ /* .rng = */ std::mt19937(seed_cur),
|
|
|
},
|
|
|
};
|
|
|
}
|
|
|
@@ -1135,6 +1157,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
|
|
|
|
|
|
struct llama_sampler_mirostat_v2 {
|
|
|
const uint32_t seed;
|
|
|
+ uint32_t seed_cur;
|
|
|
|
|
|
const float tau;
|
|
|
const float eta;
|
|
|
@@ -1179,7 +1202,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
|
|
|
static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
|
|
|
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
|
|
|
ctx->mu = 2.0f*ctx->tau;
|
|
|
- ctx->rng = std::mt19937(ctx->seed);
|
|
|
+ ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
|
+ ctx->rng.seed(ctx->seed_cur);
|
|
|
}
|
|
|
|
|
|
static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
|
|
|
@@ -1212,14 +1236,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
|
|
};
|
|
|
|
|
|
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
|
|
+ auto seed_cur = get_rng_seed(seed);
|
|
|
return new llama_sampler {
|
|
|
/* .iface = */ &llama_sampler_mirostat_v2_i,
|
|
|
/* .ctx = */ new llama_sampler_mirostat_v2 {
|
|
|
- /* .seed = */ seed,
|
|
|
- /* .tau = */ tau,
|
|
|
- /* .eta = */ eta,
|
|
|
- /* .mu = */ 2.0f*tau,
|
|
|
- /* .rng = */ std::mt19937(seed),
|
|
|
+ /* .seed = */ seed,
|
|
|
+ /* .seed_cur = */ seed_cur,
|
|
|
+ /* .tau = */ tau,
|
|
|
+ /* .eta = */ eta,
|
|
|
+ /* .mu = */ 2.0f*tau,
|
|
|
+ /* .rng = */ std::mt19937(seed_cur),
|
|
|
},
|
|
|
};
|
|
|
}
|
|
|
@@ -1505,6 +1531,8 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|
|
ignore_eos = false;
|
|
|
}
|
|
|
|
|
|
+ penalty_last_n = std::max(penalty_last_n, 0);
|
|
|
+
|
|
|
return new llama_sampler {
|
|
|
/* .iface = */ &llama_sampler_penalties_i,
|
|
|
/* .ctx = */ new llama_sampler_penalties {
|
|
|
@@ -1568,6 +1596,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
|
|
|
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
|
|
|
return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
|
|
|
@@ -1599,3 +1628,31 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
|
|
},
|
|
|
};
|
|
|
}
|
|
|
+
|
|
|
+// utils
|
|
|
+
|
|
|
+uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
|
|
+ if (smpl->iface == &llama_sampler_dist_i) {
|
|
|
+ return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (smpl->iface == &llama_sampler_mirostat_i) {
|
|
|
+ return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (smpl->iface == &llama_sampler_mirostat_v2_i) {
|
|
|
+ return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (smpl->iface == &llama_sampler_chain_i) {
|
|
|
+ const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
|
|
|
+ for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
|
|
|
+ const uint32_t seed = llama_sampler_get_seed(*it);
|
|
|
+ if (seed != LLAMA_DEFAULT_SEED) {
|
|
|
+ return seed;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return LLAMA_DEFAULT_SEED;
|
|
|
+}
|