|
|
@@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl(
|
|
|
|
|
|
constexpr short NC = (C/8)/NSG;
|
|
|
|
|
|
- // note: do not unroll for large heads
|
|
|
- #pragma unroll (DK <= 64 ? NC : 1)
|
|
|
- for (short cc = 0; cc < NC; ++cc) {
|
|
|
+ FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
|
|
|
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
|
|
|
|
|
|
if (DK % 16 != 0) {
|
|
|
@@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl(
|
|
|
k8x8_t mk[2];
|
|
|
q8x8_t mq[2];
|
|
|
|
|
|
- FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
|
|
|
+ // note: too much unroll can tank the performance for large heads
|
|
|
+ #pragma unroll (MIN(DK8/2, 4*NSG))
|
|
|
+ for (short i = 0; i < DK8/2; ++i) {
|
|
|
simdgroup_barrier(mem_flags::mem_none);
|
|
|
|
|
|
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
|
|
|
@@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl(
|
|
|
pv += 8*NS20;
|
|
|
}
|
|
|
} else {
|
|
|
- FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
|
|
|
+ constexpr short NC = (C/8)/2;
|
|
|
+
|
|
|
+ FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
|
|
|
s8x8_t vs[2];
|
|
|
|
|
|
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
|
|
|
@@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext(
|
|
|
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
|
|
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
|
|
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
|
|
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
|
|
}
|
|
|
#undef FWD_TMPL
|
|
|
#undef FWD_ARGS
|