Przeglądaj źródła

CUDA: fix Gemma 2 numerical issues for FA (#9166)

Johannes Gäßler 1 rok temu
rodzic
commit
f91fc5639b
1 zmienionych plików z 1 dodań i 1 usunięć
  1. 1 1
      src/llama.cpp

+ 1 - 1
src/llama.cpp

@@ -8877,7 +8877,7 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }