|
@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
|
|
|
|
|
|
|
+size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
|
|
|
llama_io_write_dummy io;
|
|
llama_io_write_dummy io;
|
|
|
try {
|
|
try {
|
|
|
- return state_seq_write_data(io, seq_id);
|
|
|
|
|
|
|
+ return state_seq_write_data(io, seq_id, flags);
|
|
|
} catch (const std::exception & err) {
|
|
} catch (const std::exception & err) {
|
|
|
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
|
|
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
|
|
|
return 0;
|
|
return 0;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
|
|
|
|
|
|
|
+size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
|
|
|
llama_io_write_buffer io(dst, size);
|
|
llama_io_write_buffer io(dst, size);
|
|
|
try {
|
|
try {
|
|
|
- return state_seq_write_data(io, seq_id);
|
|
|
|
|
|
|
+ return state_seq_write_data(io, seq_id, flags);
|
|
|
} catch (const std::exception & err) {
|
|
} catch (const std::exception & err) {
|
|
|
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
|
|
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
|
|
|
return 0;
|
|
return 0;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
|
|
|
|
|
|
|
+size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
|
|
|
llama_io_read_buffer io(src, size);
|
|
llama_io_read_buffer io(src, size);
|
|
|
try {
|
|
try {
|
|
|
- return state_seq_read_data(io, seq_id);
|
|
|
|
|
|
|
+ return state_seq_read_data(io, seq_id, flags);
|
|
|
} catch (const std::exception & err) {
|
|
} catch (const std::exception & err) {
|
|
|
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
|
|
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
|
|
|
return 0;
|
|
return 0;
|
|
@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
|
|
|
{
|
|
{
|
|
|
const size_t state_size = file.size() - file.tell();
|
|
const size_t state_size = file.size() - file.tell();
|
|
|
llama_io_read_file io(&file);
|
|
llama_io_read_file io(&file);
|
|
|
- const size_t nread = state_seq_read_data(io, seq_id);
|
|
|
|
|
|
|
+ const size_t nread = state_seq_read_data(io, seq_id, 0);
|
|
|
if (!nread) {
|
|
if (!nread) {
|
|
|
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
|
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
|
|
return 0;
|
|
return 0;
|
|
@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
|
|
|
|
|
|
|
|
// save the context state using stream saving
|
|
// save the context state using stream saving
|
|
|
llama_io_write_file io(&file);
|
|
llama_io_write_file io(&file);
|
|
|
- state_seq_write_data(io, seq_id);
|
|
|
|
|
|
|
+ state_seq_write_data(io, seq_id, 0);
|
|
|
|
|
|
|
|
const size_t res = file.tell();
|
|
const size_t res = file.tell();
|
|
|
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
|
|
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
|
|
@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
|
return io.n_bytes();
|
|
return io.n_bytes();
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
|
|
|
|
|
|
+size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
|
|
GGML_UNUSED(seq_id);
|
|
GGML_UNUSED(seq_id);
|
|
|
|
|
|
|
|
if (memory) {
|
|
if (memory) {
|
|
|
- memory->state_write(io, seq_id);
|
|
|
|
|
|
|
+ memory->state_write(io, seq_id, flags);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return io.n_bytes();
|
|
return io.n_bytes();
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
|
|
|
|
|
|
+size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
|
|
GGML_UNUSED(seq_id);
|
|
GGML_UNUSED(seq_id);
|
|
|
|
|
|
|
|
if (memory) {
|
|
if (memory) {
|
|
|
- memory->state_read(io, seq_id);
|
|
|
|
|
|
|
+ memory->state_read(io, seq_id, flags);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return io.n_bytes();
|
|
return io.n_bytes();
|
|
@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
|
|
size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
|
|
|
- return ctx->state_seq_get_size(seq_id);
|
|
|
|
|
|
|
+ return llama_state_seq_get_size_ext(ctx, seq_id, 0);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
|
size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
|
|
|
|
+ return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
|
|
|
|
|
+ return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
|
|
|
|
+ return ctx->state_seq_get_size(seq_id, flags);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
|
|
ctx->synchronize();
|
|
ctx->synchronize();
|
|
|
|
|
|
|
|
- return ctx->state_seq_get_data(seq_id, dst, size);
|
|
|
|
|
|
|
+ return ctx->state_seq_get_data(seq_id, dst, size, flags);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
|
|
|
|
|
|
|
+size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
|
|
ctx->synchronize();
|
|
ctx->synchronize();
|
|
|
|
|
|
|
|
- return ctx->state_seq_set_data(seq_id, src, size);
|
|
|
|
|
|
|
+ return ctx->state_seq_set_data(seq_id, src, size, flags);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
|
|
size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
|