|
|
@@ -27,6 +27,7 @@
|
|
|
#include <thread>
|
|
|
#include <atomic>
|
|
|
#include <mutex>
|
|
|
+#include <sstream>
|
|
|
|
|
|
#define LLAMA_USE_SCRATCH
|
|
|
#define LLAMA_MAX_SCRATCH_BUFFERS 16
|
|
|
@@ -1787,7 +1788,7 @@ struct llama_context * llama_init_from_file(
|
|
|
if (params.logits_all) {
|
|
|
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
|
|
|
} else {
|
|
|
- ctx->logits.reserve(hparams.n_ctx);
|
|
|
+ ctx->logits.reserve(hparams.n_vocab);
|
|
|
}
|
|
|
|
|
|
if (params.embedding){
|
|
|
@@ -2252,3 +2253,122 @@ const char * llama_print_system_info(void) {
|
|
|
std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) {
|
|
|
return ctx->model.tensors_by_name;
|
|
|
}
|
|
|
+
|
|
|
+// Returns the size of the state
|
|
|
+size_t llama_get_state_size(struct llama_context * ctx) {
|
|
|
+ const size_t s_bool = sizeof(int32_t);
|
|
|
+ // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
|
|
|
+ // for reference, std::mt19937(1337) serializes to 6701 bytes.
|
|
|
+ const size_t s_rng_size = sizeof(size_t);
|
|
|
+ const size_t s_rng = 64*1024;
|
|
|
+ const size_t s_logits_capacity = sizeof(size_t);
|
|
|
+ const size_t s_logits_size = sizeof(size_t);
|
|
|
+ const size_t s_logits = ctx->logits.capacity() * sizeof(float);
|
|
|
+ const size_t s_embedding_size = sizeof(size_t);
|
|
|
+ const size_t s_embedding = ctx->embedding.size() * sizeof(float);
|
|
|
+ const size_t s_kv_size = sizeof(size_t);
|
|
|
+ const size_t s_kv_ntok = sizeof(int);
|
|
|
+ const size_t s_kv = llama_get_kv_cache_size(ctx);
|
|
|
+ const size_t s_total = (
|
|
|
+ + s_rng_size
|
|
|
+ + s_rng
|
|
|
+ + s_logits_capacity
|
|
|
+ + s_logits_size
|
|
|
+ + s_logits
|
|
|
+ + s_embedding_size
|
|
|
+ + s_embedding
|
|
|
+ + s_kv_size
|
|
|
+ + s_kv_ntok
|
|
|
+ + s_kv
|
|
|
+ );
|
|
|
+ return s_total;
|
|
|
+}
|
|
|
+
|
|
|
+// Copies the state to the specified destination address
|
|
|
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
|
|
|
+ std::stringstream rng_ss;
|
|
|
+ rng_ss << ctx->rng;
|
|
|
+ const size_t rng_size = rng_ss.str().size();
|
|
|
+ char rng_buf[64*1024];
|
|
|
+ memset(&rng_buf[0], 0, 64*1024);
|
|
|
+ memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
|
|
|
+ const size_t logits_capacity = ctx->logits.capacity();
|
|
|
+ const size_t logits_size = ctx->logits.size();
|
|
|
+ const size_t embedding_size = ctx->embedding.size();
|
|
|
+ const size_t kv_size = llama_get_kv_cache_size(ctx);
|
|
|
+ const int kv_ntok = llama_get_kv_cache_token_count(ctx);
|
|
|
+
|
|
|
+ uint8_t * out = dest;
|
|
|
+ memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t);
|
|
|
+ memcpy(out, &rng_buf[0], 64*1024); out += 64*1024;
|
|
|
+ memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
|
|
|
+ memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t);
|
|
|
+ if (logits_size) {
|
|
|
+ memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
|
|
|
+ }
|
|
|
+ out += logits_capacity * sizeof(float);
|
|
|
+ memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
|
|
|
+ if (embedding_size) {
|
|
|
+ memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
|
|
|
+ }
|
|
|
+ memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
|
|
|
+ memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int);
|
|
|
+ if (kv_size) {
|
|
|
+ memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
|
|
|
+ }
|
|
|
+ const size_t written = out - dest;
|
|
|
+ const size_t expected = llama_get_state_size(ctx);
|
|
|
+ LLAMA_ASSERT(written == expected);
|
|
|
+ return written;
|
|
|
+}
|
|
|
+
|
|
|
+// Sets the state reading from the specified source address
|
|
|
+size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
|
|
+ size_t rng_size;
|
|
|
+ char rng_buf[64*1024];
|
|
|
+ std::stringstream rng_ss;
|
|
|
+
|
|
|
+ const uint8_t * in = src;
|
|
|
+ memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t);
|
|
|
+ memcpy(&rng_buf[0], in, 64*1024); in += 64*1024;
|
|
|
+ rng_ss.str(std::string(&rng_buf[0], rng_size));
|
|
|
+ rng_ss >> ctx->rng;
|
|
|
+ LLAMA_ASSERT(rng_ss.fail() == false);
|
|
|
+
|
|
|
+ size_t logits_capacity;
|
|
|
+ size_t logits_size;
|
|
|
+ size_t embedding_size;
|
|
|
+ size_t kv_size;
|
|
|
+ int kv_ntok;
|
|
|
+
|
|
|
+ memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
|
|
|
+ memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t);
|
|
|
+ LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
|
|
|
+ if (logits_size) {
|
|
|
+ ctx->logits.resize(logits_size);
|
|
|
+ memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
|
|
|
+ }
|
|
|
+ in += logits_capacity * sizeof(float);
|
|
|
+ memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
|
|
|
+ LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
|
|
|
+ if (embedding_size) {
|
|
|
+ memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
|
|
|
+ in += embedding_size * sizeof(float);
|
|
|
+ }
|
|
|
+ memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
|
|
|
+ memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int);
|
|
|
+ if (kv_size) {
|
|
|
+ LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
|
|
|
+ void * k_data = ctx->model.kv_self.k->data; // remember data pointers
|
|
|
+ void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
|
|
|
+ memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
|
|
|
+ ctx->model.kv_self.k->data = k_data; // restore correct data pointers
|
|
|
+ ctx->model.kv_self.v->data = v_data;
|
|
|
+ in += kv_size;
|
|
|
+ }
|
|
|
+ ctx->model.kv_self.n = kv_ntok;
|
|
|
+ const size_t nread = in - src;
|
|
|
+ const size_t expected = llama_get_state_size(ctx);
|
|
|
+ LLAMA_ASSERT(nread == expected);
|
|
|
+ return nread;
|
|
|
+}
|