@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int stride_K,
const int stride_V,
const int stride_mask,
- const int jt,
half2 * const __restrict__ tile_Q,
half2 * const __restrict__ tile_K,
half2 * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
cp_async_wait_all();
__syncthreads();
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
- (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+ (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
} else {
constexpr bool use_cp_async = nstages == 1;
if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
if (nstages <= 1) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+ (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
if (use_cp_async) {
cp_async_wait_all();
}
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+ (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
}
}

@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
if (nstages <= 1 && i0_start < reusable_cutoff) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
- (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+ (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
if (use_cp_async) {
cp_async_wait_all();
}
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
- GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
- GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+ GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+ (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
}

// Iterate over ne11 == previous tokens:
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
- ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
- ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}

// With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16(
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
- const int ne00,
- const int ne01,
- const int ne02,
- const int ne03,
- const int ne10,
- const int ne11,
- const int ne12,
- const int ne13,
- const int ne31,
- const int ne32,
- const int ne33,
- const int nb31,
- const int nb32,
- const int nb33,
- const int nb01,
- const int nb02,
- const int nb03,
- const int nb11,
- const int nb12,
- const int nb13,
- const int nb21,
- const int nb22,
- const int nb23,
- const int ne0,
- const int ne1,
- const int ne2,
- const int ne3) {
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)

// Skip unused kernel variants for faster compilation:
@@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
- GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+ GGML_UNUSED(nb22); GGML_UNUSED(nb23);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
}
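
The offset changes above all apply the same pattern: one operand of each index*stride product is cast to int64_t before the multiplication, so the element offset into K_h2/V_h2 is computed in 64-bit rather than 32-bit arithmetic. A minimal, self-contained sketch of the hazard being addressed (the values below are hypothetical and are not taken from the kernel):

    // Illustration only: shows why int64_t(k_VKQ_0)*stride_K differs from a
    // 32-bit product once the element offset exceeds INT32_MAX.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int k_VKQ_0  = 70000; // hypothetical KV position of the current tile
        const int stride_K = 40000; // hypothetical K row stride in half2 elements

        const int64_t promoted = int64_t(k_VKQ_0) * stride_K; // 2800000000, correct
        const int32_t wrapped  = int32_t(promoted);           // value a 32-bit product would leave behind

        printf("64-bit offset: %lld\n", (long long) promoted);
        printf("32-bit offset: %d\n", wrapped); // negative on typical targets: the pointer would move backwards
        return 0;
    }

The widened kernel parameters in the signature hunk follow the same reasoning at the byte level: nb13, nb23 and nb33 are passed as int64_t while the remaining shape and stride arguments stay int32_t.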