|
|
@@ -406,6 +406,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
#endif // CP_ASYNC_AVAILABLE
|
|
|
|
|
|
#else
|
|
|
+ GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
|
|
+ GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
|
|
+ GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
|
|
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV);
|
|
|
+ GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
|
|
+ GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
|
|
+ GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
|
|
|
+ GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
|
|
|
+ GGML_UNUSED(kb0);
|
|
|
NO_DEVICE_CODE;
|
|
|
#endif // NEW_MMA_AVAILABLE
|
|
|
}
|
|
|
@@ -797,6 +806,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
__syncthreads();
|
|
|
}
|
|
|
#else
|
|
|
+ GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
|
|
+ GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
|
|
+ GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
|
|
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
|
|
|
+ GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask);
|
|
|
+ GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
|
|
|
NO_DEVICE_CODE;
|
|
|
#endif // NEW_MMA_AVAILABLE
|
|
|
}
|
|
|
@@ -931,6 +946,16 @@ static __global__ void flash_attn_ext_f16(
|
|
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
|
|
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
|
|
#else
|
|
|
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
|
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
|
|
+ GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
|
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
|
|
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
|
|
|
+ GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
|
+ GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
|
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
|
|
|
+ GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
|
+ GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
|
NO_DEVICE_CODE;
|
|
|
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
|
|
}
|
|
|
@@ -985,38 +1010,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
|
|
|
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
|
|
|
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8);
|
|
|
-
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16);
|
|
|
-
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32);
|
|
|
-
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64);
|
|
|
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64);
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8)
|
|
|
+
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16)
|
|
|
+
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32)
|
|
|
+
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64)
|
|
|
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64)
|
|
|
|
|
|
// Kernels with ncols == 128 are only 4% faster due to register pressure.
|
|
|
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
|
|
|
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
|
|
|
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
|
|
|
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
|
|
|
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
|
|
|
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
|
|
|
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
|
|
|
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
|
|
|
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
|
|
|
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
|
|
|
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
|
|
|
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
|