|
|
@@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-#define FATTN_VEC_CASE(D, type_K, type_V) \
|
|
|
- if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
|
|
- ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
|
|
- return; \
|
|
|
- } \
|
|
|
+#define FATTN_VEC_CASE(D, type_K, type_V) \
|
|
|
+ { \
|
|
|
+ const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
|
|
|
+ const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
|
|
|
+ if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
|
|
|
+ ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
|
|
+ return; \
|
|
|
+ } \
|
|
|
+ } \
|
|
|
|
|
|
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
|
|
|
FATTN_VEC_CASE( 64, type_K, type_V) \
|
|
|
@@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|
|
#endif // GGML_CUDA_FA_ALL_QUANTS
|
|
|
|
|
|
switch (K->type) {
|
|
|
+ case GGML_TYPE_F32:
|
|
|
case GGML_TYPE_F16:
|
|
|
break;
|
|
|
case GGML_TYPE_Q4_1:
|
|
|
@@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|
|
// If Turing tensor cores available, use them:
|
|
|
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
|
|
|
if (can_use_vector_kernel) {
|
|
|
- if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
|
|
+ if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
|
|
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
|
|
return BEST_FATTN_KERNEL_VEC;
|
|
|
}
|
|
|
@@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|
|
|
|
|
// If there are no tensor cores available, use the generic tile kernel:
|
|
|
if (can_use_vector_kernel) {
|
|
|
- if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
|
|
+ if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
|
|
if (Q->ne[1] == 1) {
|
|
|
if (!gqa_opt_applies) {
|
|
|
return BEST_FATTN_KERNEL_VEC;
|