|
@@ -47,7 +47,7 @@ int main(int argc, char ** argv) {
|
|
|
// save state (rng, logits, embedding and kv_cache) to file
|
|
// save state (rng, logits, embedding and kv_cache) to file
|
|
|
{
|
|
{
|
|
|
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
|
|
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
|
|
|
- const size_t written = llama_state_get_data(ctx, state_mem.data());
|
|
|
|
|
|
|
+ const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
|
|
|
|
|
|
|
|
FILE *fp_write = fopen("dump_state.bin", "wb");
|
|
FILE *fp_write = fopen("dump_state.bin", "wb");
|
|
|
fwrite(state_mem.data(), 1, written, fp_write);
|
|
fwrite(state_mem.data(), 1, written, fp_write);
|
|
@@ -99,13 +99,16 @@ int main(int argc, char ** argv) {
|
|
|
|
|
|
|
|
// load state (rng, logits, embedding and kv_cache) from file
|
|
// load state (rng, logits, embedding and kv_cache) from file
|
|
|
{
|
|
{
|
|
|
- std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
|
|
|
|
|
|
|
+ std::vector<uint8_t> state_mem;
|
|
|
|
|
|
|
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
|
|
|
|
+ fseek(fp_read, 0, SEEK_END);
|
|
|
|
|
+ state_mem.resize(ftell(fp_read));
|
|
|
|
|
+ fseek(fp_read, 0, SEEK_SET);
|
|
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
|
|
fclose(fp_read);
|
|
fclose(fp_read);
|
|
|
|
|
|
|
|
- if (read != llama_state_set_data(ctx2, state_mem.data())) {
|
|
|
|
|
|
|
+ if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
|
|
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
|
|
llama_free(ctx2);
|
|
llama_free(ctx2);
|
|
|
llama_free_model(model);
|
|
llama_free_model(model);
|
|
@@ -159,13 +162,16 @@ int main(int argc, char ** argv) {
|
|
|
|
|
|
|
|
// load state (rng, logits, embedding and kv_cache) from file
|
|
// load state (rng, logits, embedding and kv_cache) from file
|
|
|
{
|
|
{
|
|
|
- std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
|
|
|
|
|
|
|
+ std::vector<uint8_t> state_mem;
|
|
|
|
|
|
|
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
|
|
|
|
+ fseek(fp_read, 0, SEEK_END);
|
|
|
|
|
+ state_mem.resize(ftell(fp_read));
|
|
|
|
|
+ fseek(fp_read, 0, SEEK_SET);
|
|
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
|
|
fclose(fp_read);
|
|
fclose(fp_read);
|
|
|
|
|
|
|
|
- if (read != llama_state_set_data(ctx3, state_mem.data())) {
|
|
|
|
|
|
|
+ if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
|
|
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
|
|
llama_free(ctx3);
|
|
llama_free(ctx3);
|
|
|
llama_free_model(model);
|
|
llama_free_model(model);
|
|
@@ -182,7 +188,7 @@ int main(int argc, char ** argv) {
|
|
|
{
|
|
{
|
|
|
// save kv of seq 0
|
|
// save kv of seq 0
|
|
|
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
|
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
|
|
- const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
|
|
|
|
|
|
|
+ const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
|
|
|
if (ncopy != seq_store.size()) {
|
|
if (ncopy != seq_store.size()) {
|
|
|
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
|
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
|
|
llama_free(ctx3);
|
|
llama_free(ctx3);
|
|
@@ -196,7 +202,7 @@ int main(int argc, char ** argv) {
|
|
|
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
|
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
|
|
|
|
|
|
|
// restore kv into seq 1
|
|
// restore kv into seq 1
|
|
|
- const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
|
|
|
|
|
|
|
+ const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
|
|
|
if (nset != seq_store.size()) {
|
|
if (nset != seq_store.size()) {
|
|
|
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
|
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
|
|
llama_free(ctx3);
|
|
llama_free(ctx3);
|