|
|
@@ -8,58 +8,32 @@
|
|
|
#include "fattn-wmma-f16.cuh"
|
|
|
#include "fattn.cuh"
|
|
|
|
|
|
-template <int D, int ncols2>
|
|
|
+template <int DKQ, int DV, int ncols2>
|
|
|
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
const ggml_tensor * Q = dst->src[0];
|
|
|
|
|
|
- if (Q->ne[1] <= 8/ncols2) {
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
|
|
|
- return;
|
|
|
+ if constexpr (ncols2 <= 8) {
|
|
|
+ if (Q->ne[1] <= 8/ncols2) {
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
|
|
|
+ return;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
if (Q->ne[1] <= 16/ncols2) {
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
if (Q->ne[1] <= 32/ncols2) {
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
|
|
|
-}
|
|
|
-
|
|
|
-template <int ncols2>
|
|
|
-static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
- const ggml_tensor * Q = dst->src[0];
|
|
|
-
|
|
|
- switch (Q->ne[0]) {
|
|
|
- case 64:
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
|
|
|
- break;
|
|
|
- case 80:
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
|
|
|
- break;
|
|
|
- case 96:
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
|
|
|
- break;
|
|
|
- case 112:
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
|
|
|
- break;
|
|
|
- case 128:
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
|
|
- break;
|
|
|
- case 256:
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
|
|
- break;
|
|
|
- default:
|
|
|
- GGML_ABORT("fatal error");
|
|
|
- break;
|
|
|
- }
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
|
|
|
}
|
|
|
|
|
|
-static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
+template <int DKQ, int DV>
|
|
|
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
const ggml_tensor * KQV = dst;
|
|
|
const ggml_tensor * Q = dst->src[0];
|
|
|
const ggml_tensor * K = dst->src[1];
|
|
|
@@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|
|
float max_bias = 0.0f;
|
|
|
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
|
|
|
|
|
- const float use_gqa_opt = mask && max_bias == 0.0f;
|
|
|
+ const bool use_gqa_opt = mask && max_bias == 0.0f;
|
|
|
|
|
|
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
|
|
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
|
|
|
|
|
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- if (use_gqa_opt && gqa_ratio == 4) {
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
|
|
|
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- if (use_gqa_opt && gqa_ratio == 2) {
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
|
|
|
+ if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
|
|
|
+}
|
|
|
+
|
|
|
+static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
+ const ggml_tensor * KQV = dst;
|
|
|
+ const ggml_tensor * Q = dst->src[0];
|
|
|
+ const ggml_tensor * K = dst->src[1];
|
|
|
+ const ggml_tensor * V = dst->src[2];
|
|
|
+ const ggml_tensor * mask = dst->src[3];
|
|
|
+
|
|
|
+ switch (Q->ne[0]) {
|
|
|
+ case 64:
|
|
|
+ GGML_ASSERT(V->ne[0] == 64);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst);
|
|
|
+ break;
|
|
|
+ case 80:
|
|
|
+ GGML_ASSERT(V->ne[0] == 80);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst);
|
|
|
+ break;
|
|
|
+ case 96:
|
|
|
+ GGML_ASSERT(V->ne[0] == 96);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst);
|
|
|
+ break;
|
|
|
+ case 112:
|
|
|
+ GGML_ASSERT(V->ne[0] == 112);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
|
|
|
+ break;
|
|
|
+ case 128:
|
|
|
+ GGML_ASSERT(V->ne[0] == 128);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
|
|
|
+ break;
|
|
|
+ case 256:
|
|
|
+ GGML_ASSERT(V->ne[0] == 256);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
|
|
+ break;
|
|
|
+ case 576: {
|
|
|
+ // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
|
|
+ GGML_ASSERT(V->ne[0] == 512);
|
|
|
+ float max_bias = 0.0f;
|
|
|
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
|
|
+
|
|
|
+ const bool use_gqa_opt = mask && max_bias == 0.0f;
|
|
|
+ GGML_ASSERT(use_gqa_opt);
|
|
|
+
|
|
|
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
|
|
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
|
|
|
+ GGML_ASSERT(gqa_ratio % 16 == 0);
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
|
|
+ } break;
|
|
|
+ default:
|
|
|
+ GGML_ABORT("fatal error");
|
|
|
+ break;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
|
|
@@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
|
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
|
|
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
|
|
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
|
|
|
- const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
|
|
|
+ const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
|
|
|
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
|
|
if (prec == GGML_PREC_DEFAULT) {
|
|
|
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|