|
@@ -2049,27 +2049,24 @@ typedef void (flash_attn_ext_f16_t)(
|
|
|
device const char * v,
|
|
device const char * v,
|
|
|
device const char * mask,
|
|
device const char * mask,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
|
|
constant int64_t & ne01,
|
|
constant int64_t & ne01,
|
|
|
constant int64_t & ne02,
|
|
constant int64_t & ne02,
|
|
|
constant int64_t & ne03,
|
|
constant int64_t & ne03,
|
|
|
- constant uint64_t & nb00,
|
|
|
|
|
constant uint64_t & nb01,
|
|
constant uint64_t & nb01,
|
|
|
constant uint64_t & nb02,
|
|
constant uint64_t & nb02,
|
|
|
constant uint64_t & nb03,
|
|
constant uint64_t & nb03,
|
|
|
- constant int64_t & ne10,
|
|
|
|
|
constant int64_t & ne11,
|
|
constant int64_t & ne11,
|
|
|
constant int64_t & ne12,
|
|
constant int64_t & ne12,
|
|
|
constant int64_t & ne13,
|
|
constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
|
|
constant uint64_t & nb11,
|
|
constant uint64_t & nb11,
|
|
|
constant uint64_t & nb12,
|
|
constant uint64_t & nb12,
|
|
|
constant uint64_t & nb13,
|
|
constant uint64_t & nb13,
|
|
|
|
|
+ constant uint64_t & nb21,
|
|
|
|
|
+ constant uint64_t & nb22,
|
|
|
|
|
+ constant uint64_t & nb23,
|
|
|
constant uint64_t & nb31,
|
|
constant uint64_t & nb31,
|
|
|
- constant int64_t & ne0,
|
|
|
|
|
constant int64_t & ne1,
|
|
constant int64_t & ne1,
|
|
|
constant int64_t & ne2,
|
|
constant int64_t & ne2,
|
|
|
- constant int64_t & ne3,
|
|
|
|
|
constant float & scale,
|
|
constant float & scale,
|
|
|
constant float & max_bias,
|
|
constant float & max_bias,
|
|
|
constant float & m0,
|
|
constant float & m0,
|
|
@@ -2090,27 +2087,24 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
device const char * v,
|
|
device const char * v,
|
|
|
device const char * mask,
|
|
device const char * mask,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
|
|
constant int64_t & ne01,
|
|
constant int64_t & ne01,
|
|
|
constant int64_t & ne02,
|
|
constant int64_t & ne02,
|
|
|
constant int64_t & ne03,
|
|
constant int64_t & ne03,
|
|
|
- constant uint64_t & nb00,
|
|
|
|
|
constant uint64_t & nb01,
|
|
constant uint64_t & nb01,
|
|
|
constant uint64_t & nb02,
|
|
constant uint64_t & nb02,
|
|
|
constant uint64_t & nb03,
|
|
constant uint64_t & nb03,
|
|
|
- constant int64_t & ne10,
|
|
|
|
|
constant int64_t & ne11,
|
|
constant int64_t & ne11,
|
|
|
constant int64_t & ne12,
|
|
constant int64_t & ne12,
|
|
|
constant int64_t & ne13,
|
|
constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
|
|
constant uint64_t & nb11,
|
|
constant uint64_t & nb11,
|
|
|
constant uint64_t & nb12,
|
|
constant uint64_t & nb12,
|
|
|
constant uint64_t & nb13,
|
|
constant uint64_t & nb13,
|
|
|
|
|
+ constant uint64_t & nb21,
|
|
|
|
|
+ constant uint64_t & nb22,
|
|
|
|
|
+ constant uint64_t & nb23,
|
|
|
constant uint64_t & nb31,
|
|
constant uint64_t & nb31,
|
|
|
- constant int64_t & ne0,
|
|
|
|
|
constant int64_t & ne1,
|
|
constant int64_t & ne1,
|
|
|
constant int64_t & ne2,
|
|
constant int64_t & ne2,
|
|
|
- constant int64_t & ne3,
|
|
|
|
|
constant float & scale,
|
|
constant float & scale,
|
|
|
constant float & max_bias,
|
|
constant float & max_bias,
|
|
|
constant float & m0,
|
|
constant float & m0,
|
|
@@ -2180,10 +2174,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
const short ne22 = ne12;
|
|
const short ne22 = ne12;
|
|
|
const short ne23 = ne13;
|
|
const short ne23 = ne13;
|
|
|
|
|
|
|
|
- const uint nb21 = nb11;
|
|
|
|
|
- const uint nb22 = nb12;
|
|
|
|
|
- const uint nb23 = nb13;
|
|
|
|
|
-
|
|
|
|
|
// broadcast
|
|
// broadcast
|
|
|
const short rk2 = ne02/ne12;
|
|
const short rk2 = ne02/ne12;
|
|
|
const short rk3 = ne03/ne13;
|
|
const short rk3 = ne03/ne13;
|
|
@@ -2247,11 +2237,16 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // mqk = mqk*scale + mask*slope
|
|
|
|
|
- simdgroup_half8x8 mm;
|
|
|
|
|
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
|
|
|
|
- simdgroup_multiply(mm, mslope, mm);
|
|
|
|
|
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
|
|
|
|
|
|
+ if (mask != q) {
|
|
|
|
|
+ // mqk = mqk*scale + mask*slope
|
|
|
|
|
+ simdgroup_half8x8 mm;
|
|
|
|
|
+ simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
|
|
|
|
+ simdgroup_multiply(mm, mslope, mm);
|
|
|
|
|
+ simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // mqk = mqk*scale
|
|
|
|
|
+ simdgroup_multiply(mqk, mscale, mqk);
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
|
|
}
|
|
}
|
|
@@ -2425,27 +2420,24 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
device const char * v,
|
|
device const char * v,
|
|
|
device const char * mask,
|
|
device const char * mask,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
|
|
constant int64_t & ne01,
|
|
constant int64_t & ne01,
|
|
|
constant int64_t & ne02,
|
|
constant int64_t & ne02,
|
|
|
constant int64_t & ne03,
|
|
constant int64_t & ne03,
|
|
|
- constant uint64_t & nb00,
|
|
|
|
|
constant uint64_t & nb01,
|
|
constant uint64_t & nb01,
|
|
|
constant uint64_t & nb02,
|
|
constant uint64_t & nb02,
|
|
|
constant uint64_t & nb03,
|
|
constant uint64_t & nb03,
|
|
|
- constant int64_t & ne10,
|
|
|
|
|
constant int64_t & ne11,
|
|
constant int64_t & ne11,
|
|
|
constant int64_t & ne12,
|
|
constant int64_t & ne12,
|
|
|
constant int64_t & ne13,
|
|
constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
|
|
constant uint64_t & nb11,
|
|
constant uint64_t & nb11,
|
|
|
constant uint64_t & nb12,
|
|
constant uint64_t & nb12,
|
|
|
constant uint64_t & nb13,
|
|
constant uint64_t & nb13,
|
|
|
|
|
+ constant uint64_t & nb21,
|
|
|
|
|
+ constant uint64_t & nb22,
|
|
|
|
|
+ constant uint64_t & nb23,
|
|
|
constant uint64_t & nb31,
|
|
constant uint64_t & nb31,
|
|
|
- constant int64_t & ne0,
|
|
|
|
|
constant int64_t & ne1,
|
|
constant int64_t & ne1,
|
|
|
constant int64_t & ne2,
|
|
constant int64_t & ne2,
|
|
|
- constant int64_t & ne3,
|
|
|
|
|
constant float & scale,
|
|
constant float & scale,
|
|
|
constant float & max_bias,
|
|
constant float & max_bias,
|
|
|
constant float & m0,
|
|
constant float & m0,
|
|
@@ -2521,10 +2513,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
const short ne22 = ne12;
|
|
const short ne22 = ne12;
|
|
|
const short ne23 = ne13;
|
|
const short ne23 = ne13;
|
|
|
|
|
|
|
|
- const uint nb21 = nb11;
|
|
|
|
|
- const uint nb22 = nb12;
|
|
|
|
|
- const uint nb23 = nb13;
|
|
|
|
|
-
|
|
|
|
|
// broadcast
|
|
// broadcast
|
|
|
const short rk2 = ne02/ne12;
|
|
const short rk2 = ne02/ne12;
|
|
|
const short rk3 = ne03/ne13;
|
|
const short rk3 = ne03/ne13;
|
|
@@ -2589,8 +2577,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
|
|
|
|
|
// mqk = mqk*scale + mask*slope
|
|
// mqk = mqk*scale + mask*slope
|
|
|
if (tiisg == 0) {
|
|
if (tiisg == 0) {
|
|
|
- float4 mm = (float4) mp4[ic/4 + cc];
|
|
|
|
|
- mqk = mqk*scale + mm*slope;
|
|
|
|
|
|
|
+ mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
|
|
|
|
|
|
|
|
ss4[cc] = mqk;
|
|
ss4[cc] = mqk;
|
|
|
}
|
|
}
|