|
@@ -1380,8 +1380,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
|
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
|
|
|
|
|
|
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
|
- int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
|
|
|
- const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
|
|
|
|
|
+ int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
|
|
|
+ const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
|
|
|
|
|
|
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
|
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
|
|
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
|
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
|
@@ -1401,7 +1401,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
|
|
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
|
|
|
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
|
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
|
|
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
|
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
|
|
- (const half *) (mask + nb33*(sequence % ne33));
|
|
|
|
|
|
|
+ (const half *) (mask + nb33*(sequence % ne33));
|
|
|
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
|
|
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
|
|
|
|
|
|
|
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|