|
|
@@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
const int ne1,
|
|
|
const int ne2,
|
|
|
const int ne3) {
|
|
|
-#ifndef NEW_MMA_AVAILABLE
|
|
|
- NO_DEVICE_CODE;
|
|
|
- return;
|
|
|
-#endif // NEW_MMA_AVAILABLE
|
|
|
+#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
|
|
|
|
|
// Skip unused kernel variants for faster compilation:
|
|
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
|
|
@@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16(
|
|
|
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
|
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
|
|
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
|
|
+#else
|
|
|
+ NO_DEVICE_CODE;
|
|
|
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
|
|
}
|
|
|
|
|
|
template <int D, int ncols1, int ncols2>
|