瀏覽代碼

llama: fix early stop in params_fit if ctx is set (#18070)

Johannes Gäßler 1 月之前
父節點
當前提交
ec98e20021
共有 1 個文件被更改,包括 7 次插入4 次删除
  1. 7 4
      src/llama.cpp

+ 7 - 4
src/llama.cpp

@@ -241,6 +241,13 @@ static void llama_params_fit_impl(
                     global_surplus += memory_reduction;
                     LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
                         __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
+                    if (global_surplus >= 0) {
+                        if (nd == 1) {
+                            LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__);
+                            return;
+                        }
+                        LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__);
+                    }
                 } else {
                     LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
                         __func__, hp_nct, n_ctx_min);
@@ -249,10 +256,6 @@ static void llama_params_fit_impl(
                 LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx);
             }
         }
-        if (global_surplus >= 0) {
-            LLAMA_LOG_INFO("%s: entire model can be fit across devices by reducing context\n", __func__);
-            return;
-        }
     }
 
     if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) {