|
|
@@ -1919,6 +1919,12 @@ struct sql_printer : public printer {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+struct ctx_state {
|
|
|
+ int depth = 0; // in tokens
|
|
|
+
|
|
|
+ std::vector<uint8_t> buf; // the llama_context state buffer
|
|
|
+};
|
|
|
+
|
|
|
static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
|
|
|
llama_set_n_threads(ctx, n_threads, n_threads);
|
|
|
|
|
|
@@ -2051,6 +2057,10 @@ int main(int argc, char ** argv) {
|
|
|
llama_model * lmodel = nullptr;
|
|
|
const cmd_params_instance * prev_inst = nullptr;
|
|
|
|
|
|
+ // store the llama_context state at the previous depth that we performed a test
|
|
|
+ // ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
|
|
|
+ ctx_state cstate;
|
|
|
+
|
|
|
int params_idx = 0;
|
|
|
auto params_count = params_instances.size();
|
|
|
for (const auto & inst : params_instances) {
|
|
|
@@ -2134,14 +2144,37 @@ int main(int argc, char ** argv) {
|
|
|
llama_memory_clear(llama_get_memory(ctx), false);
|
|
|
|
|
|
if (t.n_depth > 0) {
|
|
|
- if (params.progress) {
|
|
|
- fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
|
|
|
- i + 1, params.reps);
|
|
|
+ bool is_cached = t.n_depth == cstate.depth;
|
|
|
+
|
|
|
+ if (is_cached) {
|
|
|
+ // if previously we have computed at this depth, just restore the state
|
|
|
+ const size_t ret = llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
|
|
|
+ if (ret == 0) {
|
|
|
+ // if the old state is incompatible with the current context - reprocess from scratch
|
|
|
+ is_cached = false;
|
|
|
+ }
|
|
|
}
|
|
|
- bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
|
|
|
- if (!res) {
|
|
|
- fprintf(stderr, "%s: error: failed to run depth\n", __func__);
|
|
|
- exit(1);
|
|
|
+
|
|
|
+ if (!is_cached) {
|
|
|
+ if (params.progress) {
|
|
|
+ fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
|
|
|
+ i + 1, params.reps);
|
|
|
+ }
|
|
|
+ bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
|
|
|
+ if (!res) {
|
|
|
+ fprintf(stderr, "%s: error: failed to run depth\n", __func__);
|
|
|
+ exit(1);
|
|
|
+ }
|
|
|
+
|
|
|
+ // store the context state for reuse in later runs
|
|
|
+ cstate.depth = t.n_depth;
|
|
|
+ cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
|
|
|
+ llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
|
|
|
+ } else {
|
|
|
+ if (params.progress) {
|
|
|
+ fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d (cached)\n", params_idx, params_count,
|
|
|
+ i + 1, params.reps);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|