|
|
@@ -4351,7 +4351,7 @@ kernel void kernel_leaky_relu_f32_4(
|
|
|
|
|
|
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
|
|
|
|
|
|
-constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];
|
|
|
+constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
|
|
|
|
|
|
// pad the last chunk of C elements of k and v into an extra pad buffer
|
|
|
kernel void kernel_flash_attn_ext_pad(
|
|
|
@@ -4419,6 +4419,65 @@ kernel void kernel_flash_attn_ext_pad(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
|
|
|
+constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];
|
|
|
+
|
|
|
+// scan the blocks of the mask that are not masked
|
|
|
+// 0 - masked (i.e. full of -INF, skip)
|
|
|
+// 1 - not masked (i.e. at least one element of the mask is not -INF)
|
|
|
+kernel void kernel_flash_attn_ext_blk(
|
|
|
+ constant ggml_metal_kargs_flash_attn_ext_blk & args, // kernel arguments; this kernel reads ne01, ne30, ne32, nb31, nb32, nb33
|
|
|
+ device const char * mask, // mask data, addressed with the byte strides nb31/nb32/nb33 and read as half
|
|
|
+ device char * dst, // output: one byte (0 or 1) per C x Q block
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]], // [0] = block column, [1] = block row, [2] = combined index split with ne32 below
|
|
|
+ ushort tiisg[[thread_index_in_simdgroup]]) { // lane index; each lane inspects its own columns of the block
|
|
|
+ // block size C x Q (both values come from function constants)
|
|
|
+ const int32_t Q = FC_flash_attn_ext_blk_nqptg; // rows per block
|
|
|
+ const int32_t C = FC_flash_attn_ext_blk_ncpsg; // columns per block
|
|
|
+
|
|
|
+ constexpr short NW = N_SIMDWIDTH; // presumably the number of lanes in a simdgroup - defined outside this excerpt
|
|
|
+ // decode the grid position: tgpig[2] packs two indices that are split using ne32
|
|
|
+ const int32_t i3 = tgpig[2]/args.ne32;
|
|
|
+ const int32_t i2 = tgpig[2]%args.ne32;
|
|
|
+ const int32_t i1 = tgpig[1]; // block row index
|
|
|
+ const int32_t i0 = tgpig[0]; // block column index
|
|
|
+ // a block that extends past ne30 columns (partial last chunk) is marked 1 up front and is never scanned
|
|
|
+ char res = i0*C + C > args.ne30 ? 1 : 0;
|
|
|
+ // this lane's first element: row i1*Q, column i0*C + tiisg (the nb3x strides are applied as byte offsets before the cast to half)
|
|
|
+ device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
|
|
|
+
|
|
|
+ // fast route: one element per lane from the first row of the block; any value above -MAXHALF/2 marks the block as not masked
|
|
|
+ if (res == 0) {
|
|
|
+ if (simd_max(*mask_src) > -MAXHALF/2) {
|
|
|
+ res = 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // detailed check of the elements of the block (skipped when C <= NW and Q <= 1, or when the block is already marked 1)
|
|
|
+ if ((C > NW || Q > 1) && res == 0) {
|
|
|
+ half m = -MAXHALF; // per-lane running maximum
|
|
|
+ // NOTE(review): rows i1*Q + j are read without a bound check here - presumably the mask holds at least nblk1*Q rows; confirm against the host code
|
|
|
+ FOR_UNROLL (short j = 0; j < Q; ++j) {
|
|
|
+ FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { // C/NW is an integer division - presumably C is a multiple of NW; TODO confirm
|
|
|
+ m = max(m, mask_src[ii*NW]); // this lane visits columns that are NW apart
|
|
|
+ }
|
|
|
+
|
|
|
+ mask_src += args.nb31/2; // advance one row: nb31 is used as a byte stride above, halved to step a half pointer
|
|
|
+ }
|
|
|
+ // reduce the per-lane maxima across the simdgroup and apply the same threshold as the fast route
|
|
|
+ if (simd_max(m) > -MAXHALF/2) {
|
|
|
+ res = 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // number of blocks along the rows (nblk1) and the columns (nblk0), rounded up
|
|
|
+ const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
|
|
|
+ const int32_t nblk0 = ((args.ne30 + C - 1)/C);
|
|
|
+ // only lane 0 stores the result: one byte per block, indexed by (i3, i2, i1, i0)
|
|
|
+ if (tiisg == 0) {
|
|
|
+ dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
|
|
|
constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
|
|
|
constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
|
|
|
@@ -4473,6 +4532,7 @@ void kernel_flash_attn_ext_impl(
|
|
|
device const char * mask,
|
|
|
device const char * sinks,
|
|
|
device const char * pad,
|
|
|
+ device const char * blk,
|
|
|
device char * dst,
|
|
|
threadgroup half * shmem_f16,
|
|
|
uint3 tgpig,
|
|
|
@@ -4538,6 +4598,13 @@ void kernel_flash_attn_ext_impl(
|
|
|
pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
|
|
|
}
|
|
|
|
|
|
+ {
|
|
|
+ const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
|
|
|
+ const int32_t nblk0 = ((args.ne11 + C - 1)/C);
|
|
|
+
|
|
|
+ blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
|
|
|
+ }
|
|
|
+
|
|
|
{
|
|
|
q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
|
|
|
|
|
|
@@ -4597,11 +4664,14 @@ void kernel_flash_attn_ext_impl(
|
|
|
|
|
|
// loop over the KV cache
|
|
|
// each simdgroup handles blocks of Q rows and C columns
|
|
|
- for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
|
|
|
- int ic = ic0;
|
|
|
+ for (int ic0 = 0; ; ++ic0) {
|
|
|
+ int ic = ic0*C;
|
|
|
+ if (ic >= args.ne11) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
|
|
|
// the last partial chunk uses the pad buffer as source
|
|
|
- if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
|
|
|
+ if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
|
|
|
k = pad;
|
|
|
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
|
|
|
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
|
|
|
@@ -4640,6 +4710,14 @@ void kernel_flash_attn_ext_impl(
|
|
|
|
|
|
// read the mask into shared mem
|
|
|
if (FC_flash_attn_ext_has_mask) {
|
|
|
+ if (blk[ic0] == 0) {
|
|
|
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
|
+ pm2[jj] += NW;
|
|
|
+ }
|
|
|
+
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
|
const short j = jj*NSG + sgitg;
|
|
|
|
|
|
@@ -4652,6 +4730,9 @@ void kernel_flash_attn_ext_impl(
|
|
|
pm2[jj] += NW;
|
|
|
}
|
|
|
|
|
|
+#if 0
|
|
|
+ // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks
|
|
|
+
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
// used to detect blocks full of -INF
|
|
|
@@ -4670,6 +4751,7 @@ void kernel_flash_attn_ext_impl(
|
|
|
|
|
|
continue;
|
|
|
}
|
|
|
+#endif
|
|
|
}
|
|
|
|
|
|
// Q*K^T
|
|
|
@@ -4687,26 +4769,24 @@ void kernel_flash_attn_ext_impl(
|
|
|
|
|
|
constexpr short NC = (C/8)/NSG;
|
|
|
|
|
|
- // TODO: not good to unroll for large contexts - not sure why?
|
|
|
+ // note: do not unroll for large heads
|
|
|
+ #pragma unroll (DK <= 64 ? NC : 1)
|
|
|
for (short cc = 0; cc < NC; ++cc) {
|
|
|
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
|
|
|
|
|
|
- if (DK8 % 16 != 0) {
|
|
|
+ if (DK % 16 != 0) {
|
|
|
k8x8_t mk;
|
|
|
q8x8_t mq;
|
|
|
|
|
|
FOR_UNROLL (short i = 0; i < DK8; ++i) {
|
|
|
simdgroup_barrier(mem_flags::mem_none);
|
|
|
|
|
|
- simdgroup_load(mk, pk, NS10, 0, true);
|
|
|
- simdgroup_load(mq, pq, DK);
|
|
|
+ simdgroup_load(mk, pk + 8*i, NS10, 0, true);
|
|
|
+ simdgroup_load(mq, pq + 8*i, DK);
|
|
|
|
|
|
simdgroup_barrier(mem_flags::mem_none);
|
|
|
|
|
|
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
|
|
-
|
|
|
- pk += 8;
|
|
|
- pq += 8;
|
|
|
}
|
|
|
} else {
|
|
|
k8x8_t mk[2];
|
|
|
@@ -4715,26 +4795,22 @@ void kernel_flash_attn_ext_impl(
|
|
|
FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
|
|
|
simdgroup_barrier(mem_flags::mem_none);
|
|
|
|
|
|
- simdgroup_load(mk[0], pk + 0*8, NS10, 0, true);
|
|
|
- simdgroup_load(mk[1], pk + 1*8, NS10, 0, true);
|
|
|
+ simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
|
|
|
+ simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);
|
|
|
|
|
|
- simdgroup_load(mq[0], pq + 0*8, DK);
|
|
|
- simdgroup_load(mq[1], pq + 1*8, DK);
|
|
|
+ simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
|
|
|
+ simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);
|
|
|
|
|
|
simdgroup_barrier(mem_flags::mem_none);
|
|
|
|
|
|
simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
|
|
|
simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
|
|
|
-
|
|
|
- pk += 16;
|
|
|
- pq += 16;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
simdgroup_store(mqk, ps, SH, 0, false);
|
|
|
|
|
|
- pk += 8*(NSG*NS10 - DK8);
|
|
|
- pq += 8*(NSG*0 - DK8);
|
|
|
+ pk += 8*(NSG*NS10);
|
|
|
ps += 8*(NSG);
|
|
|
}
|
|
|
} else {
|
|
|
@@ -4868,27 +4944,50 @@ void kernel_flash_attn_ext_impl(
|
|
|
}
|
|
|
|
|
|
{
|
|
|
- auto sst = ss;
|
|
|
-
|
|
|
device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
|
|
|
|
|
|
pv += 8*sgitg;
|
|
|
|
|
|
- FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
|
|
|
- s8x8_t vs;
|
|
|
- simdgroup_load(vs, sst, SH, 0, false);
|
|
|
+ if (DV <= 64) {
|
|
|
+ FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
|
|
|
+ s8x8_t vs;
|
|
|
+ simdgroup_load(vs, ss + 8*cc, SH, 0, false);
|
|
|
|
|
|
- FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
|
|
|
- v8x8_t mv;
|
|
|
+ FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
|
|
|
+ v8x8_t mv[2];
|
|
|
|
|
|
- simdgroup_load(mv, pv, NS20, 0, false);
|
|
|
- simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]);
|
|
|
+ simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
|
|
|
+ simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
|
|
|
|
|
|
- pv += 8*NSG;
|
|
|
+ simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
|
|
|
+ simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
|
|
|
+ }
|
|
|
+
|
|
|
+ pv += 8*NS20;
|
|
|
}
|
|
|
+ } else {
|
|
|
+ FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
|
|
|
+ s8x8_t vs[2];
|
|
|
+
|
|
|
+ simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
|
|
|
+ simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);
|
|
|
|
|
|
- pv += 8*(NS20 - NO*NSG);
|
|
|
- sst += 8;
|
|
|
+ FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
|
|
|
+ v8x8_t mv[4];
|
|
|
+
|
|
|
+ simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
|
|
|
+ simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
|
|
|
+ simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
|
|
|
+ simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
|
|
|
+
|
|
|
+ simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
|
|
|
+ simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
|
|
|
+ simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
|
|
|
+ simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
|
|
|
+ }
|
|
|
+
|
|
|
+ pv += 2*8*NS20;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -5002,7 +5101,7 @@ void kernel_flash_attn_ext_impl(
|
|
|
|
|
|
device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
|
|
|
|
|
|
- const float scale = 1.0f/S[jj];
|
|
|
+ const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
|
|
|
|
|
|
if (DV4 % NW == 0) {
|
|
|
FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
|
|
|
@@ -5047,8 +5146,8 @@ template<
|
|
|
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
|
|
short DK, // K head size
|
|
|
short DV, // V head size
|
|
|
- short Q = 8, // queries per threadgroup
|
|
|
- short C = 64> // cache items per threadgroup
|
|
|
+ short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
|
|
|
+ short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
|
|
|
kernel void kernel_flash_attn_ext(
|
|
|
constant ggml_metal_kargs_flash_attn_ext & args,
|
|
|
device const char * q,
|
|
|
@@ -5057,13 +5156,14 @@ kernel void kernel_flash_attn_ext(
|
|
|
device const char * mask,
|
|
|
device const char * sinks,
|
|
|
device const char * pad,
|
|
|
+ device const char * blk,
|
|
|
device char * dst,
|
|
|
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
|
|
|
-#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
|
|
|
+#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg
|
|
|
switch (FC_flash_attn_ext_nsg) {
|
|
|
// note: disabled cases to reduce library load time
|
|
|
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
|
|
@@ -5210,9 +5310,9 @@ template<
|
|
|
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
|
|
short DK, // K head size
|
|
|
short DV, // V head size
|
|
|
- short NE = 4, // head elements per thread
|
|
|
- short Q = 1, // queries per threadgroup
|
|
|
- short C = 32, // cache items per threadgroup
|
|
|
+ short NE, // head elements per thread
|
|
|
+ short Q, // queries per threadgroup
|
|
|
+ short C, // cache items per threadgroup
|
|
|
short NSG> // number of simd groups
|
|
|
void kernel_flash_attn_ext_vec_impl(
|
|
|
constant ggml_metal_kargs_flash_attn_ext_vec & args,
|
|
|
@@ -5327,8 +5427,8 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
|
|
|
// loop over the KV cache
|
|
|
// each simdgroup handles blocks of Q rows and C columns
|
|
|
- for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
|
|
|
- int ic = ic0 + C*sgitg;
|
|
|
+ for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {
|
|
|
+ int ic = ic0*C;
|
|
|
if (ic >= args.ne11) {
|
|
|
break;
|
|
|
}
|
|
|
@@ -5621,7 +5721,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
device float4 * dst4 = (device float4 *) dst;
|
|
|
device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
|
|
|
|
|
|
- const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f;
|
|
|
+ const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;
|
|
|
|
|
|
// interleave the workgroup data
|
|
|
for (short i = tiisg; i < DV4; i += NW) {
|
|
|
@@ -5659,8 +5759,8 @@ template<
|
|
|
short DK, // K head size
|
|
|
short DV, // V head size
|
|
|
short NE = 4, // head elements per thread
|
|
|
- short Q = 1, // queries per threadgroup
|
|
|
- short C = 32> // cache items per threadgroup
|
|
|
+ short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
|
|
|
+ short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
|
|
|
kernel void kernel_flash_attn_ext_vec(
|
|
|
constant ggml_metal_kargs_flash_attn_ext_vec & args,
|
|
|
device const char * q,
|
|
|
@@ -5799,7 +5899,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
|
|
|
const float m = simd_max(M);
|
|
|
const float ms = exp(M - m);
|
|
|
|
|
|
- S = 1.0f/simd_sum(S*ms);
|
|
|
+ S = simd_sum(S*ms);
|
|
|
+ S = S == 0.0f ? 0.0f : 1.0f/S;
|
|
|
|
|
|
const short DV4 = DV/4;
|
|
|
|