Browse Source

tool/ex/tests: consistently free ctx, then model (#18168)

Johannes Gäßler 1 month ago
parent
commit
147a521636

+ 2 - 0
common/common.cpp

@@ -1078,6 +1078,8 @@ struct common_init_result::impl {
     impl() = default;
     impl() = default;
     ~impl() = default;
     ~impl() = default;
 
 
+    // note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top
+
     llama_model_ptr   model;
     llama_model_ptr   model;
     llama_context_ptr context;
     llama_context_ptr context;
 
 

+ 16 - 17
src/llama-context.cpp

@@ -459,23 +459,22 @@ llama_context::llama_context(
 }
 }
 
 
 llama_context::~llama_context() {
 llama_context::~llama_context() {
-    // FIXME this currently results in a use-after-free bug if the model is freed before the context
-    // if (!model.hparams.no_alloc) {
-    //     for (size_t i = 0; i < backend_ptrs.size(); ++i) {
-    //         ggml_backend_t             backend = backend_ptrs[i];
-    //         ggml_backend_buffer_type_t buft    = backend_buft[i];
-
-    //         const size_t size_exp = backend_buf_exp_size[i];
-    //         const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
-    //         if (size_exp == size_act) {
-    //             LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
-    //                 __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
-    //         } else {
-    //             LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
-    //                 __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
-    //         }
-    //     }
-    // }
+    if (!model.hparams.no_alloc) {
+        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+            ggml_backend_t             backend = backend_ptrs[i];
+            ggml_backend_buffer_type_t buft    = backend_buft[i];
+
+            const size_t size_exp = backend_buf_exp_size[i];
+            const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+            if (size_exp == size_act) {
+                LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
+                    __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
+            } else {
+                LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
+                    __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
+            }
+        }
+    }
     ggml_opt_free(opt_ctx);
     ggml_opt_free(opt_ctx);
 }
 }
 
 

+ 3 - 0
tests/test-grammar-llguidance.cpp

@@ -1196,6 +1196,9 @@ int main(int argc, const char ** argv) {
 
 
     test_sampler_chain();
     test_sampler_chain();
 
 
+    llama_free(ctx);
+    llama_model_free(model);
+
     fprintf(stdout, "All tests passed.\n");
     fprintf(stdout, "All tests passed.\n");
     return 0;
     return 0;
 }
 }

+ 1 - 1
tests/test-tokenizer-0.cpp

@@ -300,8 +300,8 @@ int main(int argc, char **argv) {
         fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
         fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
     }
     }
 
 
-    llama_model_free(model);
     llama_free(ctx);
     llama_free(ctx);
+    llama_model_free(model);
 
 
     llama_backend_free();
     llama_backend_free();
 
 

+ 1 - 1
tests/test-tokenizer-1-bpe.cpp

@@ -146,8 +146,8 @@ int main(int argc, char **argv) {
         }
         }
     }
     }
 
 
-    llama_model_free(model);
     llama_free(ctx);
     llama_free(ctx);
+    llama_model_free(model);
 
 
     llama_backend_free();
     llama_backend_free();
 
 

+ 1 - 1
tests/test-tokenizer-1-spm.cpp

@@ -116,8 +116,8 @@ int main(int argc, char ** argv) {
         }
         }
     }
     }
 
 
-    llama_model_free(model);
     llama_free(ctx);
     llama_free(ctx);
+    llama_model_free(model);
 
 
     llama_backend_free();
     llama_backend_free();
 
 

+ 11 - 0
tools/batched-bench/batched-bench.cpp

@@ -55,6 +55,7 @@ int main(int argc, char ** argv) {
 
 
     if (ctx == NULL) {
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+        llama_model_free(model);
         return 1;
         return 1;
     }
     }
 
 
@@ -108,6 +109,8 @@ int main(int argc, char ** argv) {
 
 
         if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
         if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             LOG_ERR("%s: llama_decode() failed\n", __func__);
+            llama_free(ctx);
+            llama_model_free(model);
             return 1;
             return 1;
         }
         }
     }
     }
@@ -147,6 +150,8 @@ int main(int argc, char ** argv) {
 
 
                 if (!decode_helper(ctx, batch, ctx_params.n_batch, false)) {
                 if (!decode_helper(ctx, batch, ctx_params.n_batch, false)) {
                     LOG_ERR("%s: llama_decode() failed\n", __func__);
                     LOG_ERR("%s: llama_decode() failed\n", __func__);
+                    llama_free(ctx);
+                    llama_model_free(model);
                     return 1;
                     return 1;
                 }
                 }
 
 
@@ -165,6 +170,8 @@ int main(int argc, char ** argv) {
                         common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
                         common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
                         if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                         if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                             LOG_ERR("%s: llama_decode() failed\n", __func__);
                             LOG_ERR("%s: llama_decode() failed\n", __func__);
+                            llama_free(ctx);
+                            llama_model_free(model);
                             return 1;
                             return 1;
                         }
                         }
                         llama_memory_seq_rm(mem, 0, pp, -1);
                         llama_memory_seq_rm(mem, 0, pp, -1);
@@ -184,6 +191,8 @@ int main(int argc, char ** argv) {
 
 
                             if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                             if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                                 LOG_ERR("%s: llama_decode() failed\n", __func__);
                                 LOG_ERR("%s: llama_decode() failed\n", __func__);
+                                llama_free(ctx);
+                                llama_model_free(model);
                                 return 1;
                                 return 1;
                             }
                             }
                         }
                         }
@@ -200,6 +209,8 @@ int main(int argc, char ** argv) {
 
 
                         if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                         if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                             LOG_ERR("%s: llama_decode() failed\n", __func__);
                             LOG_ERR("%s: llama_decode() failed\n", __func__);
+                            llama_free(ctx);
+                            llama_model_free(model);
                             return 1;
                             return 1;
                         }
                         }
                     }
                     }

+ 14 - 0
tools/llama-bench/llama-bench.cpp

@@ -2102,6 +2102,8 @@ int main(int argc, char ** argv) {
         struct ggml_threadpool_params tpp = ggml_threadpool_params_default(t.n_threads);
         struct ggml_threadpool_params tpp = ggml_threadpool_params_default(t.n_threads);
         if (!parse_cpu_mask(t.cpu_mask, tpp.cpumask)) {
         if (!parse_cpu_mask(t.cpu_mask, tpp.cpumask)) {
             fprintf(stderr, "%s: failed to parse cpu-mask: %s\n", __func__, t.cpu_mask.c_str());
             fprintf(stderr, "%s: failed to parse cpu-mask: %s\n", __func__, t.cpu_mask.c_str());
+            llama_free(ctx);
+            llama_model_free(lmodel);
             exit(1);
             exit(1);
         }
         }
         tpp.strict_cpu = t.cpu_strict;
         tpp.strict_cpu = t.cpu_strict;
@@ -2111,6 +2113,8 @@ int main(int argc, char ** argv) {
         struct ggml_threadpool * threadpool = ggml_threadpool_new_fn(&tpp);
         struct ggml_threadpool * threadpool = ggml_threadpool_new_fn(&tpp);
         if (!threadpool) {
         if (!threadpool) {
             fprintf(stderr, "%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
             fprintf(stderr, "%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
+            llama_free(ctx);
+            llama_model_free(lmodel);
             exit(1);
             exit(1);
         }
         }
 
 
@@ -2126,6 +2130,8 @@ int main(int argc, char ** argv) {
                 bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
                 bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
                 if (!res) {
                 if (!res) {
                     fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
                     fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
+                    llama_free(ctx);
+                    llama_model_free(lmodel);
                     exit(1);
                     exit(1);
                 }
                 }
             }
             }
@@ -2136,6 +2142,8 @@ int main(int argc, char ** argv) {
                 bool res = test_gen(ctx, 1, t.n_threads);
                 bool res = test_gen(ctx, 1, t.n_threads);
                 if (!res) {
                 if (!res) {
                     fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
                     fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
+                    llama_free(ctx);
+                    llama_model_free(lmodel);
                     exit(1);
                     exit(1);
                 }
                 }
             }
             }
@@ -2164,6 +2172,8 @@ int main(int argc, char ** argv) {
                     bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
                     bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
                     if (!res) {
                     if (!res) {
                         fprintf(stderr, "%s: error: failed to run depth\n", __func__);
                         fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                        llama_free(ctx);
+                        llama_model_free(lmodel);
                         exit(1);
                         exit(1);
                     }
                     }
 
 
@@ -2189,6 +2199,8 @@ int main(int argc, char ** argv) {
                 bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
                 bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
                 if (!res) {
                 if (!res) {
                     fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
                     fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
+                    llama_free(ctx);
+                    llama_model_free(lmodel);
                     exit(1);
                     exit(1);
                 }
                 }
             }
             }
@@ -2200,6 +2212,8 @@ int main(int argc, char ** argv) {
                 bool res = test_gen(ctx, t.n_gen, t.n_threads);
                 bool res = test_gen(ctx, t.n_gen, t.n_threads);
                 if (!res) {
                 if (!res) {
                     fprintf(stderr, "%s: error: failed to run gen\n", __func__);
                     fprintf(stderr, "%s: error: failed to run gen\n", __func__);
+                    llama_free(ctx);
+                    llama_model_free(lmodel);
                     exit(1);
                     exit(1);
                 }
                 }
             }
             }