@@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
}
}
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
+ bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
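
(The first hunk is purely cosmetic: the overlong template parameter list of `flash_attn_ext_f16_iter` is wrapped onto two lines, with no functional change.)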
@@ -922,7 +923,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
 
// Iterate over ne11 == previous tokens:
- for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+ int kb0 = kb0_start;
+ for (; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
@@ -932,7 +934,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
- ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
 
// With multi-stage loading there is no __syncthreads at the end of the iter,
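
The two hunks above hoist the induction variable `kb0` out of the loop over KV blocks so that the peeled `last_iter = true` call can pass the live `kb0` instead of recomputing `kb0_stop-1`. Behavior is unchanged in the normal case `kb0_start < kb0_stop`, where `kb0 == kb0_stop-1` holds after the loop; the peeled call now simply continues from wherever the loop stopped. A minimal stand-alone sketch of the pattern, with a hypothetical `process_block` standing in for `flash_attn_ext_f16_iter`:

```cpp
#include <cstdio>

// Hypothetical stand-in for flash_attn_ext_f16_iter: handles one KV block,
// with the last block getting special treatment via the template parameter.
template <bool last_iter>
void process_block(int kb0) {
    std::printf("block %d%s\n", kb0, last_iter ? " (last)" : "");
}

int main() {
    const int kb0_start = 0;
    const int kb0_stop  = 4;

    // kb0 now outlives the loop, so the peeled final call continues from
    // wherever the loop stopped:
    int kb0 = kb0_start;
    for (; kb0 < kb0_stop - 1; ++kb0) {
        process_block</*last_iter=*/false>(kb0); // blocks kb0_start .. kb0_stop-2
    }
    process_block</*last_iter=*/true>(kb0);      // here kb0 == kb0_stop-1
    return 0;
}
```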
@@ -1204,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
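
The kernel gains an optional `KV_max` argument. Judging from the null check and the `KV_max[sequence*iter_j + jt]` indexing in the hunks below, it holds one entry per (sequence, Q tile) pair: an upper bound on how far into the KV cache that tile needs to attend, so that KV blocks beyond the bound can be skipped entirely. How the buffer is produced is not part of this excerpt; the following is only a hypothetical host-side sketch, assuming a plain causal mask with `n_past` cached tokens and `ncols1` Q rows per tile (all names are illustrative, not from the ggml code):

```cpp
#include <vector>

// Fill one KV upper bound per (sequence, Q tile). With a causal mask, Q row i
// attends to KV positions [0, n_past + i], so a tile's bound comes from its
// last row. The bound is rounded UP to a multiple of the KV block size,
// because the kernel converts it to a block count with integer division.
std::vector<int> build_kv_max(int n_sequences, int iter_j, int ncols1,
                              int n_past, int nbatch_fa) {
    std::vector<int> kv_max((size_t) n_sequences * iter_j);
    for (int s = 0; s < n_sequences; ++s) {
        for (int jt = 0; jt < iter_j; ++jt) {
            const int last_row = (jt + 1)*ncols1 - 1;
            const int bound    = n_past + last_row + 1; // KV entries this tile needs
            kv_max[s*iter_j + jt] = ((bound + nbatch_fa - 1) / nbatch_fa) * nbatch_fa;
        }
    }
    return kv_max;
}
```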
@@ -1280,7 +1283,11 @@ static __global__ void flash_attn_ext_f16(
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
const int kb0_start_kernel = kb0_start * kb_niter;
- const int kb0_stop_kernel = kb0_stop * kb_niter;
+ int kb0_stop_kernel = kb0_stop * kb_niter;
+
+ if (KV_max) {
+ kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+ }
 
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
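
The clamp itself: `KV_max[sequence*iter_j + jt]` is a count of KV entries, and the integer division by `c::nbatch_fa` (the number of KV rows processed per iteration) converts it into a number of KV blocks. For example, with `c::nbatch_fa == 64`, a tile whose `KV_max` entry is 128 gets `kb0_stop_kernel` capped at 2, i.e. it processes KV positions 0..127 and skips the rest of the cache. Note that the division truncates, so this is only lossless if the `KV_max` entries are multiples of `c::nbatch_fa`; presumably the code that fills the buffer guarantees this (cf. the rounding in the sketch above).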
@@ -1321,7 +1328,11 @@ static __global__ void flash_attn_ext_f16(
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
const int kb0_start_kernel = kb0_start * kb_niter;
- const int kb0_stop_kernel = kb0_stop * kb_niter;
+ int kb0_stop_kernel = kb0_stop * kb_niter;
+
+ if (KV_max) {
+ kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+ }
 
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
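
The same clamp is repeated in the `is_fixup = true` branch, which writes its partial results to the fixup buffer instead of `dst`, so the `KV_max` cap takes effect regardless of which output path a given tile takes.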