@@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (cc == GGML_CUDA_CC_VOLTA) {
+    if (fp16_mma_available(cc) && !new_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
     }
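For context: the new condition falls back to the WMMA kernel on devices that have FP16 tensor-core (WMMA) support but lack the newer MMA path required by the MMA implementation, which in practice is Volta, instead of hard-coding an equality check against a single compute capability. Below is a minimal, self-contained sketch of that logic; the constant values and predicate bodies are assumptions for illustration, not ggml's actual definitions.

```cpp
// Sketch only: standalone illustration of the capability predicates implied
// by the new condition. Constants are assumed to encode compute capability
// as major*100 + minor*10; they are not ggml's real definitions.
#include <cstdio>

constexpr int GGML_CUDA_CC_VOLTA  = 700;  // assumed: CC 7.0
constexpr int GGML_CUDA_CC_TURING = 750;  // assumed: CC 7.5

// FP16 tensor-core (WMMA) support exists from Volta onwards.
static bool fp16_mma_available(int cc) { return cc >= GGML_CUDA_CC_VOLTA; }

// The newer MMA path needs Turing or newer.
static bool new_mma_available(int cc)  { return cc >= GGML_CUDA_CC_TURING; }

int main() {
    // The combined condition is true exactly for Volta-class devices:
    // they have FP16 MMA but not the newer MMA path.
    for (int cc : {610, 700, 750, 800}) {
        const bool use_wmma = fp16_mma_available(cc) && !new_mma_available(cc);
        std::printf("cc=%d -> WMMA fallback: %s\n", cc, use_wmma ? "yes" : "no");
    }
    return 0;
}
```

Expressing the fallback through capability predicates keeps the dispatch correct if further architectures with the same property are added, rather than tying it to one named constant.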