|
|
@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
|
|
|
const int ne12,
|
|
|
const int ne13,
|
|
|
const int ne31,
|
|
|
+ const int ne32,
|
|
|
const int nb31,
|
|
|
+ const int nb32,
|
|
|
const int nb01,
|
|
|
const int nb02,
|
|
|
const int nb03,
|
|
|
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
|
|
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
|
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
|
- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
|
|
+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
|
|
+ (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
|
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
|
|
|
|
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
|
|
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
|
|
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
|
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
|
- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
|
|
+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
|
|
+ (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
|
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
|
|
|
|
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
|
|
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
|
|
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
|
|
|
- GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
|
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
|
+ GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
|
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
|
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
|
|
|
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
|
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|