@@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
- bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
+ bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
@@ -442,8 +442,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = nbatch_K2 + 4;
- static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
- constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+ constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
const int k_VKQ_0 = kb0 * nbatch_fa;
#if defined(TURING_MMA_AVAILABLE)
@@ -456,7 +455,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
if constexpr (nstages > 1) {
static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
- static_assert(!mla, "multi-stage loading not implemented for MLA");
+ static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
constexpr bool use_cp_async = true;
cp_async_wait_all();
@@ -471,8 +470,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
+ // For MLA, K and V have the same data.
+ // Therefore, iterate over K in reverse and later re-use the data if possible.
#pragma unroll
- for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
+ for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
const int k0_diff = k0_stop - k0_start;
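A note on the reversed bounds: (DKQ/2 - 1) - (DKQ/2 - 1) % nbatch_K2 is DKQ/2 - 1 rounded down to a multiple of nbatch_K2, i.e. the start of the last, possibly partial, K batch. The loop therefore visits exactly the same batches as the old forward loop, just in reverse, and finishes on [0, nbatch_K2) so that this batch is still resident in shared memory when the V pass starts. A minimal host-side check of the bounds, with illustrative constants rather than the kernel's real configuration:

    #include <cstdio>

    int main() {
        constexpr int DKQ       = 576; // MLA head size targeted by this patch
        constexpr int nbatch_K2 = 80;  // illustrative value only

        // Same bounds as the new device-side loop above:
        for (int k0_start = (DKQ/2 - 1) - (DKQ/2 - 1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
            const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
            printf("[%3d, %3d)\n", k0_start, k0_stop); // [240, 288) [160, 240) [ 80, 160) [  0,  80)
        }
        return 0;
    }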
@@ -776,6 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
if constexpr (nstages > 1) {
+ static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
// Preload K tile for next iteration:
constexpr bool use_cp_async = true;
cp_async_wait_all();
@@ -791,11 +793,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
- // For MLA K and V have the same data.
- // Therefore, iterate over V in reverse and re-use the data if possible.
- static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
- // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
- constexpr int reusable_cutoff = DV; // TODO implement properly
#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
T_A_VKQ A_identity;
make_identity_mat(A_identity);
@@ -803,12 +800,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
#pragma unroll
- for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
- const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
- const int i0_diff = i0_stop - i0_start;
+ for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
+ static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
+ const int i0_stop = i0_start + 2*nbatch_V2;
+ const int i0_diff = i0_stop - i0_start;
if constexpr (nstages <= 1) {
- if (i0_start < reusable_cutoff) {
+ if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
@@ -818,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
__syncthreads();
}
}
- const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
+ const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
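This hunk is the counterpart of the reversed K loop and replaces the old reusable_cutoff logic with a containment test: after the K loop ends on the batch covering logical elements [0, 2*nbatch_K2), that data is still resident in shared memory, and since V is a view of K (and tile_V shares tile_K's row stride) it can double as V data. A V batch [i0_start, i0_stop) is loaded from global memory only when !V_is_K_view || i0_stop > 2*nbatch_K2; otherwise tile_V_i simply points into the resident tile at offset i0_start/2. A sketch of which batches get reused, again with illustrative constants:

    #include <cstdio>

    int main() {
        constexpr bool V_is_K_view = true; // MLA: V aliases K
        constexpr int  DV          = 512;  // illustrative
        constexpr int  nbatch_V2   = 64;   // illustrative, keeps DV % (2*nbatch_V2) == 0
        constexpr int  nbatch_K2   = 80;   // illustrative

        for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
            const int  i0_stop = i0_start + 2*nbatch_V2;
            const bool load    = !V_is_K_view || i0_stop > 2*nbatch_K2;
            printf("V batch [%3d, %3d): %s\n", i0_start, i0_stop,
                load ? "load from global memory" : "reuse K tile in shared memory");
        }
        return 0;
    }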
@@ -921,7 +919,7 @@ template<int ncols> struct mma_tile_sizes {
};
#endif // defined(TURING_MMA_AVAILABLE)
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
@@ -975,8 +973,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = nbatch_K2 + 4;
- static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
- constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+ constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
extern __shared__ half2 tile_Q[];
@@ -1080,7 +1077,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr bool last_iter = false;
constexpr int k_VKQ_sup = nbatch_fa;
flash_attn_ext_f16_iter
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1089,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr bool last_iter = true;
const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
flash_attn_ext_f16_iter
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1100,7 +1097,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr bool last_iter = false;
constexpr int k_VKQ_sup = nbatch_fa;
flash_attn_ext_f16_iter
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1109,7 +1106,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr bool last_iter = true;
constexpr int k_VKQ_sup = nbatch_fa;
flash_attn_ext_f16_iter
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1457,7 +1454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
}
-template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
+template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
@@ -1509,8 +1506,6 @@ static __global__ void flash_attn_ext_f16(
}
#endif // defined(AMD_WMMA_AVAILABLE)
- static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
-
constexpr int ncols = ncols1 * ncols2;
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
@@ -1523,7 +1518,7 @@ static __global__ void flash_attn_ext_f16(
const int stride_K = nb11 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half);
- const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
+ const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
@@ -1553,7 +1548,7 @@ static __global__ void flash_attn_ext_f16(
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
- const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+ const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@@ -1564,12 +1559,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
} else {
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
}
@@ -1597,7 +1592,7 @@ static __global__ void flash_attn_ext_f16(
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
- const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+ const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@@ -1608,7 +1603,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
#else
@@ -1644,7 +1639,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
const int nwarps = nthreads / WARP_SIZE;
- constexpr bool mla = DKQ == 576;
+ constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
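Regarding V_is_K_view = DKQ == 576 above: head size 576 only occurs for MLA (e.g. DeepSeek-style attention, where the 512-element value head is stored as the first part of K and a 64-element RoPE section is appended), and in that case ggml passes V as a view of K. The kernel encodes this by substituting V_h2 = K_h2 and stride_V = stride_K, which is only valid if V aliases K's data at offset 0 with the same row stride. A hypothetical restatement of that invariant as a host-side check (a sketch; the real gating lives in the selection logic in fattn.cu):

    #include "ggml.h"

    // Sketch of the precondition assumed whenever V_is_K_view is instantiated:
    // same base pointer and same row stride, so V_h2 = K_h2 and
    // stride_V = stride_K are valid substitutions inside the kernel.
    static bool v_is_k_view(const ggml_tensor * K, const ggml_tensor * V) {
        return V->data == K->data && V->nb[1] == K->nb[1];
    }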
@@ -1669,7 +1664,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
#if !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1680,7 +1675,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
#endif // !defined(GGML_USE_MUSA)
} else {
constexpr bool use_logit_softcap = true;
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
#if !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};