@@ -940,6 +940,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int stride_V,
         const int stride_mask,
         const int jt,
+        const int zt,
         const int kb0_start,
         const int kb0_stop) {
 #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
@@ -1022,7 +1023,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int j = jc / ncols2;
             const int c = jc % ncols2;

-            if (jt*ncols1 + j < int(ne01.z)) {
+            if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -1408,7 +1409,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int j_dst = jc_dst / ncols2;
         const int c_dst = jc_dst % ncols2;

-        if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
+        if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
             continue;
         }

@@ -1522,10 +1523,11 @@ static __global__ void flash_attn_ext_f16(

     const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
     const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
+    const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;

     // kbc == k block continuous, current index in continuous ijk space.
-    int kbc            = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
-    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    int kbc            = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
+    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;

     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1536,9 +1538,9 @@ static __global__ void flash_attn_ext_f16(
     int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);

     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
-        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+        const int sequence = kbc / (iter_k*iter_j*iter_z);
+        const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.

         const int head0 = zt * ncols2;

@@ -1561,12 +1563,12 @@ static __global__ void flash_attn_ext_f16(
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
         }

         kbc += iter_k;
@@ -1580,9 +1582,9 @@ static __global__ void flash_attn_ext_f16(
         return;
     }

-    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
-    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+    const int sequence = kbc / (iter_k*iter_j*iter_z);
+    const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.

     const int head0 = zt * ncols2;

@@ -1605,7 +1607,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-         ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+         ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
 #else
     GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
                      max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1739,3 +1741,5 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
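
For reference, a minimal host-side sketch of the kbc -> (sequence, zt, jt) decomposition used above (standalone C++, not part of the patch; all extents and tile sizes are made-up illustrative values). The point of replacing the old (ne02/ncols2) denominator with the rounded-up iter_z is that head counts not divisible by ncols2 still get a final, partial head tile, whose out-of-range heads are then masked off by the new zt*ncols2 + c < ne02 checks in the tile function:

// Standalone illustration, plain C++. All values below are made up.
#include <cassert>
#include <cstdio>

int main() {
    const int ne02 = 5, ne03 = 2;     // e.g. 5 heads, 2 sequences (illustrative)
    const int ncols2 = 4;             // head tile width
    const int iter_k = 3, iter_j = 7; // K-block and j-tile counts (illustrative)
    // Old: ne02/ncols2 == 1 tile, so head 4 is never processed when ne02 % ncols2 != 0.
    // New: iter_z rounds up, making the last head tile partial but present.
    const int iter_z = (ne02 + (ncols2 - 1)) / ncols2; // == 2 here

    for (int kbc = 0; kbc < iter_k*iter_j*iter_z*ne03; ++kbc) {
        // Same decomposition as in flash_attn_ext_f16:
        const int sequence = kbc / (iter_k*iter_j*iter_z);
        const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
        const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k;
        const int kb = kbc % iter_k;

        // Round-trip check: the flattened index decomposes uniquely.
        assert(kbc == ((sequence*iter_z + zt)*iter_j + jt)*iter_k + kb);

        // Heads covered by this tile; the kernel masks off those >= ne02.
        const int head0 = zt*ncols2;
        if (kb == 0 && head0 + ncols2 > ne02) {
            printf("partial head tile: seq=%d jt=%d heads [%d, %d)\n",
                   sequence, jt, head0, ne02);
        }
    }
    return 0;
}

With the old integer division this example would yield a head-tile count of 1, so head 4 would never be assigned to any CUDA block; rounding up to iter_z, combined with the in-kernel bounds checks, keeps the decomposition exhaustive for head counts that are not multiples of ncols2.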