
llama : less KV padding when FA is off (#7257)

ggml-ci
Georgi Gerganov 1 year ago
parent
commit
614d3b914e
1 changed file with 13 additions and 7 deletions
  1. llama.cpp  +13 -7

llama.cpp  +13 -7

@@ -2805,6 +2805,11 @@ static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
     cache.do_defrag = true;
 }
 
+static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
+
 //
 // model loading and saving
 //
@@ -11510,7 +11515,8 @@ static int llama_decode_internal(
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
-                kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
+                const uint32_t pad = llama_kv_cache_get_padding(cparams);
+                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
                 //kv_self.n = llama_kv_cache_cell_max(kv_self);
             }
         }
@@ -15511,6 +15517,11 @@ struct llama_context * llama_new_context_with_model(
         return nullptr;
     }
 
+    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
     llama_context * ctx = new llama_context(*model);
 
     const auto & hparams = model->hparams;
@@ -15534,7 +15545,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx            = GGML_PAD(cparams.n_ctx, 256);
+    cparams.n_ctx            = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -15579,11 +15590,6 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
-    if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
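
Note (not part of the commit): below is a minimal standalone sketch of the padding arithmetic this change introduces, assuming the GGML_PAD macro as defined in ggml.h; pad_for(), kv_size and cell_max are hypothetical stand-ins for llama_kv_cache_get_padding(), kv_self.size and llama_kv_cache_cell_max().

// sketch: how the clamp in llama_decode_internal rounds the active KV range
// up to the FA-dependent padding (256 with flash attention, 32 without)
#include <algorithm>
#include <cstdint>
#include <cstdio>

// reproduced from ggml.h
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

// mirrors llama_kv_cache_get_padding(): FA kernels need coarser padding
static uint32_t pad_for(bool flash_attn) {
    return flash_attn ? 256u : 32u;
}

int main() {
    const uint32_t kv_size  = 4096; // hypothetical kv_self.size
    const uint32_t cell_max = 70;   // hypothetical llama_kv_cache_cell_max() result

    for (bool fa : {true, false}) {
        const uint32_t pad = pad_for(fa);
        const uint32_t n   = std::min(kv_size, std::max(pad, GGML_PAD(cell_max, pad)));
        // flash_attn on : pad=256 -> n=256
        // flash_attn off: pad=32  -> n=96 (the "less KV padding" in the title)
        printf("flash_attn=%d pad=%3u n=%4u\n", fa, pad, n);
    }
    return 0;
}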