|
@@ -3887,6 +3887,11 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
sm[tiisg] = pm[ic + tiisg];
|
|
sm[tiisg] = pm[ic + tiisg];
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // skip -INF blocks
|
|
|
|
|
+ if (simd_max(sm[tiisg]) == -INFINITY) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
// Q*K^T
|
|
// Q*K^T
|
|
|
{
|
|
{
|
|
|
// each simdgroup processes 1 query and NE (NW/NL) head elements
|
|
// each simdgroup processes 1 query and NE (NW/NL) head elements
|