|
|
@@ -6,7 +6,7 @@
|
|
|
// nbatch_K == number of K columns to load in parallel for KQ calculation
|
|
|
|
|
|
// TODO optimize kernel parameters for FP16 NVIDIA (P100)
|
|
|
-// TODO optimize kernel parameters for head sizes 40, 80, 96, 112
|
|
|
+// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112
|
|
|
|
|
|
// The ROCm compiler cannot handle templating in __launch_bounds__.
|
|
|
// As a workaround, define a macro to package the kernel parameters as uint32_t:
|
|
|
@@ -32,6 +32,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
|
|
|
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72)
|
|
|
+
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40)
|
|
|
@@ -80,6 +86,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
|
|
|
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
|
|
+
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
|
|
@@ -130,6 +142,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
|
|
|
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
|
|
|
+
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
|
|
@@ -185,6 +204,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
|
|
|
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
|
|
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
|
|
|
+
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
|
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
|
|
@@ -723,7 +749,7 @@ static __global__ void flash_attn_tile(
|
|
|
|
|
|
if (
|
|
|
#ifdef GGML_USE_WMMA_FATTN
|
|
|
- (ncols2 != 1 && DV != 40 && DV != 512) ||
|
|
|
+ (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
|
|
|
#endif // GGML_USE_WMMA_FATTN
|
|
|
(use_logit_softcap && !(DV == 128 || DV == 256))
|
|
|
) {
|
|
|
@@ -1198,6 +1224,7 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
|
|
|
|
|
|
extern DECL_FATTN_TILE_CASE( 40, 40);
|
|
|
extern DECL_FATTN_TILE_CASE( 64, 64);
|
|
|
+extern DECL_FATTN_TILE_CASE( 72, 72);
|
|
|
extern DECL_FATTN_TILE_CASE( 80, 80);
|
|
|
extern DECL_FATTN_TILE_CASE( 96, 96);
|
|
|
extern DECL_FATTN_TILE_CASE(112, 112);
|