Просмотр исходного кода

CUDA: fix Gemma 2 numerical issues for FA (#9166)

Johannes Gäßler 1 год назад
Родитель
Сommit
f91fc5639b
1 измененных файлов с 1 добавлено и 1 удалено
  1. 1 1
      src/llama.cpp

+ 1 - 1
src/llama.cpp

@@ -8877,7 +8877,7 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }