|
|
@@ -190,7 +190,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
|
|
|
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
|
|
#endif // GGML_CUDA_FA_ALL_QUANTS
|
|
|
|
|
|
- on_no_fattn_vec_case(Q->ne[0]);
|
|
|
+ GGML_ABORT("fatal error");
|
|
|
}
|
|
|
|
|
|
#define FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
|
|
@@ -265,74 +265,184 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
|
|
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
|
|
#endif // GGML_CUDA_FA_ALL_QUANTS
|
|
|
|
|
|
- on_no_fattn_vec_case(Q->ne[0]);
|
|
|
+ GGML_ABORT("fatal error");
|
|
|
}
|
|
|
|
|
|
-void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
+// Best FlashAttention kernel for a specific GPU:
|
|
|
+enum best_fattn_kernel {
|
|
|
+ BEST_FATTN_KERNEL_NONE = 0,
|
|
|
+ BEST_FATTN_KERNEL_TILE_F32 = 200,
|
|
|
+ BEST_FATTN_KERNEL_TILE_F16 = 210,
|
|
|
+ BEST_FATTN_KERNEL_VEC_F32 = 100,
|
|
|
+ BEST_FATTN_KERNEL_VEC_F16 = 110,
|
|
|
+ BEST_FATTN_KERNEL_WMMA_F16 = 300,
|
|
|
+ BEST_FATTN_KERNEL_MMA_F16 = 400,
|
|
|
+};
|
|
|
+
|
|
|
+static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
|
|
|
+#ifndef FLASH_ATTN_AVAILABLE
|
|
|
+ GGML_UNUSED(device); GGML_UNUSED(dst);
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
+#endif// FLASH_ATTN_AVAILABLE
|
|
|
+
|
|
|
const ggml_tensor * KQV = dst;
|
|
|
const ggml_tensor * Q = dst->src[0];
|
|
|
const ggml_tensor * K = dst->src[1];
|
|
|
const ggml_tensor * V = dst->src[2];
|
|
|
const ggml_tensor * mask = dst->src[3];
|
|
|
|
|
|
- ggml_cuda_set_device(ctx.device);
|
|
|
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
|
|
- const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
|
|
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
|
|
|
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
|
|
+
|
|
|
+ const int cc = ggml_cuda_info().devices[device].cc;
|
|
|
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
|
|
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
|
|
|
|
|
-#if defined(GGML_HIP_ROCWMMA_FATTN)
|
|
|
- if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
|
|
|
- ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
|
|
- return;
|
|
|
+ switch (K->ne[0]) {
|
|
|
+ case 64:
|
|
|
+ case 128:
|
|
|
+ case 256:
|
|
|
+ if (V->ne[0] != K->ne[0]) {
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ case 80:
|
|
|
+ case 96:
|
|
|
+ case 112:
|
|
|
+ if (V->ne[0] != K->ne[0]) {
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
+ }
|
|
|
+ if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ case 576:
|
|
|
+ if (V->ne[0] != 512) {
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
+ }
|
|
|
+ if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
}
|
|
|
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
|
|
|
|
|
|
- if (!fast_fp16_available(cc)) {
|
|
|
- if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
|
|
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
|
|
- } else {
|
|
|
- ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
|
|
- }
|
|
|
- return;
|
|
|
+#ifndef GGML_CUDA_FA_ALL_QUANTS
|
|
|
+ if (K->type != V->type) {
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
}
|
|
|
+#endif // GGML_CUDA_FA_ALL_QUANTS
|
|
|
|
|
|
- if (!fp16_mma_available(cc)) {
|
|
|
- if (prec == GGML_PREC_DEFAULT) {
|
|
|
- if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
|
|
- ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
|
|
- } else {
|
|
|
- ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
|
|
+ switch (K->type) {
|
|
|
+ case GGML_TYPE_F16:
|
|
|
+ break;
|
|
|
+ case GGML_TYPE_Q4_1:
|
|
|
+ case GGML_TYPE_Q5_0:
|
|
|
+ case GGML_TYPE_Q5_1:
|
|
|
+#ifndef GGML_CUDA_FA_ALL_QUANTS
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
+#endif // GGML_CUDA_FA_ALL_QUANTS
|
|
|
+ case GGML_TYPE_Q4_0:
|
|
|
+ case GGML_TYPE_Q8_0:
|
|
|
+#ifdef GGML_CUDA_FA_ALL_QUANTS
|
|
|
+ if (K->ne[0] != 128 && K->ne[0] != 64) {
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
}
|
|
|
- } else {
|
|
|
- if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
|
|
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
|
|
- } else {
|
|
|
- ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
|
|
+#else
|
|
|
+ if (K->ne[0] != 128) {
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
}
|
|
|
- }
|
|
|
- return;
|
|
|
+#endif // GGML_CUDA_FA_ALL_QUANTS
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
+ }
|
|
|
+
|
|
|
+ switch (V->type) {
|
|
|
+ case GGML_TYPE_F16:
|
|
|
+ break;
|
|
|
+ case GGML_TYPE_Q4_1:
|
|
|
+ case GGML_TYPE_Q5_0:
|
|
|
+ case GGML_TYPE_Q5_1:
|
|
|
+ case GGML_TYPE_Q4_0:
|
|
|
+ case GGML_TYPE_Q8_0:
|
|
|
+ if (K->ne[0] != 128) {
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (mask && mask->ne[2] != 1) {
|
|
|
+ return BEST_FATTN_KERNEL_NONE;
|
|
|
}
|
|
|
|
|
|
- const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
|
|
- const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
|
|
- const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
|
|
|
- const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
|
|
|
- (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
|
|
|
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
|
|
|
- if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
|
|
- if (prec == GGML_PREC_DEFAULT) {
|
|
|
- ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
|
|
- } else {
|
|
|
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
|
|
+
|
|
|
+ // If Turing tensor cores available, use them except for some cases with batch size 1:
|
|
|
+ if (turing_mma_available(cc)) {
|
|
|
+ const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations
|
|
|
+ const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
|
|
+ const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192);
|
|
|
+ const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion &&
|
|
|
+ (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
|
|
|
+ if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
|
|
+ if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
|
|
+ return BEST_FATTN_KERNEL_VEC_F16;
|
|
|
+ }
|
|
|
+ return BEST_FATTN_KERNEL_VEC_F32;
|
|
|
}
|
|
|
- return;
|
|
|
+ return BEST_FATTN_KERNEL_MMA_F16;
|
|
|
}
|
|
|
|
|
|
- // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
|
|
- if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
|
|
|
- ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
|
|
- return;
|
|
|
+ // Use kernels specializes for small batch sizes if possible:
|
|
|
+ if (Q->ne[1] <= 8 && can_use_vector_kernel) {
|
|
|
+ if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
|
|
+ return BEST_FATTN_KERNEL_VEC_F16;
|
|
|
+ }
|
|
|
+ return BEST_FATTN_KERNEL_VEC_F32;
|
|
|
+ }
|
|
|
+
|
|
|
+ // For large batch sizes, use the WMMA kernel if possible:
|
|
|
+ if (fp16_mma_available(cc)) {
|
|
|
+ return BEST_FATTN_KERNEL_WMMA_F16;
|
|
|
+ }
|
|
|
+
|
|
|
+ // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
|
|
|
+ if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
|
|
+ return BEST_FATTN_KERNEL_TILE_F16;
|
|
|
}
|
|
|
+ return BEST_FATTN_KERNEL_TILE_F32;
|
|
|
+}
|
|
|
+
|
|
|
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
+ ggml_cuda_set_device(ctx.device);
|
|
|
+ switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
|
|
|
+ case BEST_FATTN_KERNEL_NONE:
|
|
|
+ GGML_ABORT("fatal error");
|
|
|
+ case BEST_FATTN_KERNEL_TILE_F32:
|
|
|
+ ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
|
|
+ break;
|
|
|
+ case BEST_FATTN_KERNEL_TILE_F16:
|
|
|
+ ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
|
|
+ break;
|
|
|
+ case BEST_FATTN_KERNEL_VEC_F32:
|
|
|
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
|
|
+ break;
|
|
|
+ case BEST_FATTN_KERNEL_VEC_F16:
|
|
|
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
|
|
+ break;
|
|
|
+ case BEST_FATTN_KERNEL_WMMA_F16:
|
|
|
+ ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
|
|
+ break;
|
|
|
+ case BEST_FATTN_KERNEL_MMA_F16:
|
|
|
+ ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
|
|
+ break;
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
- ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
|
|
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
|
|
|
+ return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
|
|
|
}
|