|
|
@@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
const short D4 = D/4;
|
|
|
const short D16 = D/16;
|
|
|
const short NW = N_SIMDWIDTH;
|
|
|
- const short NL = NW/4;
|
|
|
- const short SH = 2*C; // shared memory per simdgroup
|
|
|
+ const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
|
|
|
+ const short SH = 2*C; // shared memory per simdgroup
|
|
|
|
|
|
const short T = D + nsg*SH; // shared memory size per query in (half)
|
|
|
|
|
|
@@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
|
|
|
// Q*K^T
|
|
|
{
|
|
|
- // each simdgroup processes 1 query and 4 keys
|
|
|
+ // each simdgroup processes 1 query and 4 (NW/NL) keys
|
|
|
for (short cc = 0; cc < C/4; ++cc) {
|
|
|
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
|
|
|
|
|
|
@@ -3646,7 +3646,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
half, half4, half4x4, \
|
|
|
half4x4
|
|
|
|
|
|
-typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
|
|
|
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
|
|
|
|
|
|
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
|
|
|
#if defined(GGML_METAL_USE_BF16)
|