test-state-restore-fragmented.cpp

// Test for state restore with fragmented KV cache
// This tests the fix for: https://github.com/ggml-org/llama.cpp/issues/17527
// The issue was that state restore required contiguous KV cache slots,
// which fails when the cache is fragmented.
//
// The fix changes find_slot(ubatch, true) to find_slot(ubatch, false)
// in state_read_meta(), allowing non-contiguous slot allocation.
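//
// A rough sketch of the changed call inside state_read_meta() (based on the
// description above; the local variable and parameter names are assumptions):
//
//     const auto sinfo = find_slot(ubatch, /* contiguous */ false);  // previously: true
//
// With the flag set to false, the restored cells may be scattered across the
// cache instead of having to form one contiguous block.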

#include "arg.h"
#include "common.h"
#include "llama.h"

#include <vector>
#include <cstdio>
#include <cstring>

int main(int argc, char ** argv) {
    common_params params;
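
    // use a small, unified KV cache shared by 3 parallel sequences - clearing one
    // sequence later leaves scattered free cells, i.e. a fragmented cache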
    params.sampling.seed = 1234;
    params.kv_unified = true;
    params.n_parallel = 3;
    params.n_ctx = 256;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
        return 1;
    }

    common_init();

    // init
    common_init_result_ptr llama_init = common_init_from_params(params);

    llama_model * model = llama_init->model();
    llama_context * ctx = llama_init->context();

    if (model == nullptr || ctx == nullptr) {
        fprintf(stderr, "%s : failed to init\n", __func__);
        return 1;
    }

    GGML_UNUSED(model);

    // build a dummy prompt of 70 tokens (token id 1) to use for each sequence
    std::vector<llama_token> tokens(70, 1);

    // interleave the 3 sequences in the KV cache:
    // 012012012...
    llama_batch batch = llama_batch_init(params.n_parallel*tokens.size(), 0, 1);
    for (size_t i = 0; i < tokens.size(); i++) {
        for (int s = 0; s < params.n_parallel; ++s) {
            common_batch_add(batch, tokens[i], i, { s }, false);
        }
    }
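
    // request logits only for the last token in the batch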
    batch.logits[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch)) {
        fprintf(stderr, "%s : failed to decode prompt batch\n", __func__);
        llama_batch_free(batch);
        return 1;
    }

    fprintf(stderr, "%s : processed prompt on seq 0, 1, 2 (%zu tokens each)\n", __func__, tokens.size());

    // save the state of seq 1
    std::vector<uint8_t> seq_state(llama_state_seq_get_size(ctx, 1));
    const size_t ncopy = llama_state_seq_get_data(ctx, seq_state.data(), seq_state.size(), 1);
    if (ncopy != seq_state.size()) {
        fprintf(stderr, "%s : failed to save seq 1 state\n", __func__);
        llama_batch_free(batch);
        return 1;
    }
    fprintf(stderr, "%s : saved seq 1 state, %zu bytes\n", __func__, ncopy);

    // clear seq 1 to create a "hole" in the KV cache (fragmentation)
    // 0.20.20.20.2....
    llama_memory_t mem = llama_get_memory(ctx);
    llama_memory_seq_rm(mem, 1, -1, -1);
    fprintf(stderr, "%s : cleared seq 1 to create fragmentation\n", __func__);

    // Now the cache has holes where seq 1 was.
    // This creates fragmentation - there is no contiguous block large enough
    // for the seq 1 state if we only look for contiguous slots.
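    // (rough numbers: 3*70 = 210 of the 256 cells were filled by the prompt, so after
    //  clearing seq 1 the largest contiguous free run is the ~46-cell tail - smaller
    //  than the 70 cells needed to restore seq 1)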

    // Restore seq 1 state into seq 1 (should work with non-contiguous allocation).
    // We use seq 1 since it's a valid sequence id (0 to n_parallel-1).
    // Before the fix, this would fail with "failed to find available cells in kv cache".
    const size_t nset = llama_state_seq_set_data(ctx, seq_state.data(), seq_state.size(), 1);
    if (nset != seq_state.size()) {
        fprintf(stderr, "%s : FAILED to restore seq state into fragmented cache (got %zu, expected %zu)\n",
                __func__, nset, seq_state.size());
        fprintf(stderr, "%s : this is the bug - state restore fails with a fragmented KV cache\n", __func__);
        llama_batch_free(batch);
        return 1;
    }

    fprintf(stderr, "%s : restored state into seq 1, %zu bytes\n", __func__, nset);

    // Verify we can decode with the restored state:
    // generate one token to check that the restored state is usable.
    auto sparams = llama_sampler_chain_default_params();
    llama_sampler * smpl = llama_sampler_chain_init(sparams);
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
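
    // sample from the logits of the last token that produced them (idx == -1 means "last output")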
    auto next_token = llama_sampler_sample(smpl, ctx, -1);
    auto next_token_str = common_token_to_piece(ctx, next_token);
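
    // append the sampled token to seq 1 at the next position and decode it
    // against the restored state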
    common_batch_clear(batch);
    common_batch_add(batch, next_token, (int) tokens.size(), { 1 }, true);

    if (llama_decode(ctx, batch)) {
        fprintf(stderr, "%s : failed to decode with restored state\n", __func__);
        llama_sampler_free(smpl);
        llama_batch_free(batch);
        return 1;
    }

    fprintf(stderr, "%s : successfully decoded with restored state, generated: '%s'\n", __func__, next_token_str.c_str());
    fprintf(stderr, "%s : SUCCESS - state restore works with fragmented KV cache\n", __func__);

    llama_sampler_free(smpl);
    llama_batch_free(batch);

    return 0;
}