Procházet zdrojové kódy

llama : switch KQ multiplication to F32 precision by default (#10015)

ggml-ci
Georgi Gerganov před 1 rokem
rodič
revize
8841ce3f43
1 změnil soubory, kde provedl 4 přidání a 11 odebrání
  1. 4 11
      src/llama.cpp

+ 4 - 11
src/llama.cpp

@@ -9618,20 +9618,16 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
-            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-        }
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
-            // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-        }
+        // note: this op tends to require high floating point range
+        //       while for some models F16 is enough, for others it is not, so we default to F32 here
+        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
         if (model.arch == LLM_ARCH_GROK) {
             // need to do the following:
@@ -9640,9 +9636,6 @@ static struct ggml_tensor * llm_build_kqv(
             // kq = 30 * tanh(kq / 30)
             // before the softmax below
 
-            //try from phi2
-            //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
             kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
             kq = ggml_scale(ctx, kq, 30);
         }