Explorar el Código

CUDA: fix Gemma 2 numerical issues for FA (#9166)

Johannes Gäßler hace 1 año
padre
commit
f91fc5639b
Se han modificado 1 ficheros con 1 adiciones y 1 borrados
  1. 1 1
      src/llama.cpp

+ 1 - 1
src/llama.cpp

@@ -8877,7 +8877,7 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }