Browse Source

metal : improve FA + improve MoE (#12612)

* ggml : FA with different K, V head sizes (CPU)

ggml-ci

* metal : add FA with HS=192

* metal : extend FA to support different K and V head sizes

ggml-ci

* metal : add FA vector kernels for heads K 192 and V 128

ggml-ci

* ggml : restrict op on other backends to equal head sizes

ggml-ci

* metal : optimize FA-vec kernel

ggml-ci

* metal : FA remove mq registers

* metal : improve MoE mul_mat_id condition

ggml-ci

* metal : fix comments + remove unnecessary addition

ggml-ci

* metal : avoid too much shared memory usage with mul_mat_id

ggml-ci
Georgi Gerganov 9 months ago
parent
commit
b4ae50810e

+ 5 - 5
ggml/include/ggml.h

@@ -1791,11 +1791,11 @@ extern "C" {
 
 #define GGML_KQ_MASK_PAD 64
 
-    // q:    [n_embd, n_batch,     n_head,    1]
-    // k:    [n_embd, n_kv,        n_head_kv, 1]
-    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
-    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    1]
+    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
+    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   1] !! permuted !!
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,

+ 25 - 25
ggml/src/ggml-cpu/ggml-cpu.c

@@ -12238,10 +12238,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t D = neq0;
-    const int64_t N = neq1;
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
 
-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == DV);
     GGML_ASSERT(ne2 == N);
 
     // input tensor rows must be contiguous
@@ -12249,12 +12250,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
 
     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         float S = 0.0f;      // sum
         float M = -INFINITY; // maximum KQ value
 
-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
 
         if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
         } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            memset(VKQ32, 0, DV*sizeof(float));
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv2 = iq2 / rv2;
 
         const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
+        q_to_vec_dot(pq, Q_q, DK);
 
         // online softmax / attention
         // loop over n_kv and n_head_kv
@@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             float s; // KQ value
 
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
 
             s = s*scale; // scale KQ value
 
@@ -12380,14 +12380,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     ms = expf(Mold - M);
 
                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
+                    ggml_vec_scale_f16(DV, VKQ16, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }
 
                 // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
@@ -12395,30 +12395,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     ms = expf(Mold - M);
 
                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
+                    ggml_vec_scale_f32(DV, VKQ32, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }
 
-                v_to_float(v_data, V32, D);
+                v_to_float(v_data, V32, DV);
 
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
             }
 
             S = S*ms + vs; // scale and increment sum with partial sum
         }
 
         if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
+            for (int64_t d = 0; d < DV; ++d) {
                 VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
             }
         }
 
         // V /= S
         const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
+        ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
         // dst indices
         const int i1 = iq1;
@@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
         size_t cur = 0;
 
         if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
-
             switch (node->op) {
                 case GGML_OP_CPY:
                 case GGML_OP_DUP:
@@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_FLASH_ATTN_EXT:
                     {
-                        const int64_t ne00 = node->src[0]->ne[0]; // D
+                        const int64_t ne10 = node->src[1]->ne[0]; // DK
+                        const int64_t ne20 = node->src[2]->ne[0]; // DV
 
-                        cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                        cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
                     } break;
                 case GGML_OP_FLASH_ATTN_BACK:
                     {

+ 7 - 0
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3232,6 +3232,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif // FLASH_ATTN_AVAILABLE
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }

+ 6 - 3
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -219,9 +219,12 @@ typedef struct {
     int32_t  ne11;
     int32_t  ne_12_2; // assume K and V are same shape
     int32_t  ne_12_3;
-    uint64_t nb_12_1;
-    uint64_t nb_12_2;
-    uint64_t nb_12_3;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
     uint64_t nb31;
     int32_t  ne1;
     int32_t  ne2;

File diff suppressed because it is too large
+ 500 - 407
ggml/src/ggml-metal/ggml-metal.m


+ 292 - 232
ggml/src/ggml-metal/ggml-metal.metal

@@ -48,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4>
 void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
-    reg = (type4)(*(src + il));
+    reg = (type4)(*(src));
 }
 
 #if defined(GGML_METAL_USE_BF16)
@@ -56,6 +56,11 @@ template <typename type4x4>
 void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
     reg = (type4x4)(*src);
 }
+
+template <typename type4>
+void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
+    reg = (type4)(*(src));
+}
 #endif
 
 template <typename type4x4>
@@ -3100,7 +3105,8 @@ template<
     typename vd4x4_t, // key type in device memory
     short nl_v,
     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
-    short D,         // head size
+    short DK,        // K head size
+    short DV,        // V head size
     short Q  = 8,    // queries per threadgroup
     short KV = 8,    // key/value processed per each simdgroup
     short C  = 32>   // cache items per threadgroup
@@ -3122,20 +3128,23 @@ kernel void kernel_flash_attn_ext(
     const int iq2 = tgpig[1];
     const int iq1 = tgpig[0]*Q;
 
-    const short D4  = D/4;
-    const short D8  = D/8;
-    const short D16 = D/16;
+    const short DK4  = DK/4;
+    const short DK8  = DK/8;
+    const short DK16 = DK/16;
+    const short DV4  = DV/4;
+    const short DV8  = DV/8;
+    const short DV16 = DV/16;
     const short NW  = N_SIMDWIDTH;
     const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
     const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
-    const short T  = D + 2*TS; // shared memory size per query in (half)
+    const short T  = DK + 2*TS; // shared memory size per query in (half)
 
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*D); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*D); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*D); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*D); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*DK); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*DK); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*DK); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3144,23 +3153,23 @@ kernel void kernel_flash_attn_ext(
     threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    o8x8_t lo[D8];
+    o8x8_t lo[DV8];
 
     // load heads from Q to shared memory
     for (short j = sgitg; j < Q; j += nsg) {
         device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
 
-        for (short i = tiisg; i < D4; i += NW) {
+        for (short i = tiisg; i < DK4; i += NW) {
             if (iq1 + j < args.ne01) {
-                sq4[j*D4 + i] = (q4_t) q4[i];
+                sq4[j*DK4 + i] = (q4_t) q4[i];
             } else {
-                sq4[j*D4 + i] = (q4_t) 0.0f;
+                sq4[j*DK4 + i] = (q4_t) 0.0f;
             }
         }
     }
 
     // zero out lo
-    for (short i = 0; i < D8; ++i) {
+    for (short i = 0; i < DV8; ++i) {
         lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
     }
 
@@ -3190,13 +3199,6 @@ kernel void kernel_flash_attn_ext(
         const short ikv2 = iq2/(args.ne02/args.ne_12_2);
         const short ikv3 = iq3/(args.ne03/args.ne_12_3);
 
-        // load the queries from shared memory into local memory
-        q8x8_t mq[D8];
-
-        for (short i = 0; i < D8; ++i) {
-            simdgroup_load(mq[i], sq + i*8, D);
-        }
-
         const bool has_mask = mask != q;
 
         half slope = 1.0f;
@@ -3249,20 +3251,22 @@ kernel void kernel_flash_attn_ext(
                     // this is compile-time check, so it does not have runtime overhead
                     if (is_same<kd4x4_t, k4x4_t>::value) {
                         // we can read directly from global memory
-                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
 
-                        #pragma unroll(D8)
-                        for (short i = 0; i < D8; ++i) {
+                        #pragma unroll(DK8)
+                        for (short i = 0; i < DK8; ++i) {
                             k8x8_t mk;
-                            simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+                            simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
 
-                            simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                            q8x8_t mq;
+                            simdgroup_load(mq, sq + i*8, DK);
+                            simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                         }
                     } else {
-                        for (short ii = 0; ii < D16; ii += 4) {
-                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        for (short ii = 0; ii < DK16; ii += 4) {
+                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
 
-                            if (D16%4 == 0) {
+                            if (DK16%4 == 0) {
                                 // the head is evenly divisible by 4*16 = 64, so no need for bound checks
                                 {
                                     k4x4_t tmp;
@@ -3275,15 +3279,18 @@ kernel void kernel_flash_attn_ext(
                                 #pragma unroll(4)
                                 for (short k = 0; k < 4; ++k) {
                                     k8x8_t mk;
+                                    q8x8_t mq;
 
                                     simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
 
                                     simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                                 }
                             } else {
-                                if (ii + tx < D16) {
+                                if (ii + tx < DK16) {
                                     k4x4_t tmp;
                                     deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
                                     sk4x4[4*ty + tx] = tmp;
@@ -3291,14 +3298,17 @@ kernel void kernel_flash_attn_ext(
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
-                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                for (short k = 0; k < 4 && ii + k < DK16; ++k) {
                                     k8x8_t mk;
+                                    q8x8_t mq;
 
                                     simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
 
                                     simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                    simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+                                    simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                                 }
                             }
                         }
@@ -3350,8 +3360,8 @@ kernel void kernel_flash_attn_ext(
                 s8x8_t mm;
                 simdgroup_load(mm, ss + 2*C, TS, 0, false);
 
-                #pragma unroll(D8)
-                for (short i = 0; i < D8; ++i) {
+                #pragma unroll(DV8)
+                for (short i = 0; i < DV8; ++i) {
                     simdgroup_multiply(lo[i], mm, lo[i]);
                 }
             }
@@ -3364,20 +3374,20 @@ kernel void kernel_flash_attn_ext(
 
                     if (is_same<vd4x4_t, v4x4_t>::value) {
                         // we can read directly from global memory
-                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
 
-                        #pragma unroll(D8)
-                        for (short i = 0; i < D8; ++i) {
+                        #pragma unroll(DV8)
+                        for (short i = 0; i < DV8; ++i) {
                             v8x8_t mv;
-                            simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
+                            simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
 
                             simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
                         }
                     } else {
-                        for (short ii = 0; ii < D16; ii += 4) {
-                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                        for (short ii = 0; ii < DV16; ii += 4) {
+                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
 
-                            if (D16%4 == 0) {
+                            if (DV16%4 == 0) {
                                 // no need for bound checks
                                 {
                                     v4x4_t tmp;
@@ -3398,7 +3408,7 @@ kernel void kernel_flash_attn_ext(
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
                                 }
                             } else {
-                                if (ii + tx < D16) {
+                                if (ii + tx < DV16) {
                                     v4x4_t tmp;
                                     deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
                                     sv4x4[4*ty + tx] = tmp;
@@ -3406,7 +3416,7 @@ kernel void kernel_flash_attn_ext(
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
-                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                for (short k = 0; k < 4 && ii + k < DV16; ++k) {
                                     v8x8_t mv;
 
                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
@@ -3440,8 +3450,8 @@ kernel void kernel_flash_attn_ext(
 
         // each simdgroup stores its output to shared memory, reusing sq
         if (sgitg == sg) {
-            for (short i = 0; i < D8; ++i) {
-                simdgroup_store(lo[i], so + i*8, D, 0, false);
+            for (short i = 0; i < DV8; ++i) {
+                simdgroup_store(lo[i], so + i*8, DV, 0, false);
             }
         }
 
@@ -3480,11 +3490,11 @@ kernel void kernel_flash_attn_ext(
                 simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
                 simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
 
-                #pragma unroll(D8)
-                for (short i = 0; i < D8; ++i) {
+                #pragma unroll(DV8)
+                for (short i = 0; i < DV8; ++i) {
                     o8x8_t t;
 
-                    simdgroup_load    (t, so + i*8, D, 0, false);
+                    simdgroup_load    (t, so + i*8, DV, 0, false);
                     simdgroup_multiply(t, ms1, t);
 
                     simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@@ -3495,8 +3505,8 @@ kernel void kernel_flash_attn_ext(
 
     // store result to shared memory (reuse sq)
     if (sgitg == 0) {
-        for (short i = 0; i < D8; ++i) {
-            simdgroup_store(lo[i], so + i*8, D, 0, false);
+        for (short i = 0; i < DV8; ++i) {
+            simdgroup_store(lo[i], so + i*8, DV, 0, false);
         }
     }
 
@@ -3507,8 +3517,8 @@ kernel void kernel_flash_attn_ext(
         for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
             const float S = ss[j*TS + 0];
 
-            for (short i = tiisg; i < D4; i += NW) {
-                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
+            for (short i = tiisg; i < DV4; i += NW) {
+                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
             }
         }
     }
@@ -3525,80 +3535,94 @@ kernel void kernel_flash_attn_ext(
     float,          simdgroup_float8x8, \
     half,  half4,   simdgroup_half8x8
 
-typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
 
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  256>;
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  64,  64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  80,  80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  96,  96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  112, 112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  128, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_h192")]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  192, 192>;
+template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  192, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]]         kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  256, 256>;
 
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 256, 256>;
 #endif
 
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
 
 #undef FA_TYPES
 
 template<
-    typename q4_t,    // query types in shared memory
-    typename q4x4_t,
-    typename k4x4_t,  // key types in shared memory
-    typename v4x4_t,  // value types in shared memory
-    typename qk_t,    // Q*K types
-    typename s_t,     // soft-max types
+    typename q4_t,  // query types in shared memory
+    typename k4_t,  // key types in shared memory
+    typename v4_t,  // value types in shared memory
+    typename qk_t,  // Q*K types
+    typename s_t,   // soft-max types
     typename s4_t,
-    typename s4x4_t,
-    typename o4x4_t,  // attention accumulation types
-    typename kd4x4_t, // key type in device memory
+    typename o4_t,  // attention accumulation types
+    typename kd4_t, // key type in device memory
     short nl_k,
-    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
-    typename vd4x4_t, // key type in device memory
+    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
+    typename vd4_t, // key type in device memory
     short nl_v,
-    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
-    short D,         // head size
-    short Q  = 1,    // queries per threadgroup
-    short C  = 32>   // cache items per threadgroup
+    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
+    short DK,       // K head size
+    short DV,       // V head size
+    short NE = 4,   // head elements per thread
+    short Q  = 1,   // queries per threadgroup
+    short C  = 32>  // cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec(
         constant ggml_metal_kargs_flash_attn_ext & args,
         device const char * q,
@@ -3617,29 +3641,28 @@ kernel void kernel_flash_attn_ext_vec(
     const int iq2 = tgpig[1];
     const int iq1 = tgpig[0];
 
-    const short D4  = D/4;
-    const short D16 = D/16;
+    const short DK4 = DK/4;
+    const short DV4 = DV/4;
     const short NW  = N_SIMDWIDTH;
-    const short NL  = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
-    const short SH  = 2*C;  // shared memory per simdgroup
+    const short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
+    const short SH  = 2*C;   // shared memory per simdgroup
 
-    const short T = D + nsg*SH; // shared memory size per query in (half)
+    const short T = DK + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shmem_f16 +                0*D); // holds the query data
-    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shmem_f16 +                0*D); // same as above but in q4_t
-    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 +                0*D); // same as above but in q4x4_t
-    threadgroup s_t    * ss    = (threadgroup s_t    *) (shmem_f16 + sgitg*SH     + Q*D); // scratch buffer for attention
-    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shmem_f16 + sgitg*SH     + Q*D); // same as above but in s4_t
-    threadgroup half   * sm    = (threadgroup half   *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
-    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D      + Q*T); // scratch buffer for the results
+  //threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +                0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +                0*DK); // same as above but in q4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + sgitg*SH     + Q*DK); // scratch buffer for attention
+    threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH     + Q*DK); // same as above but in s4_t
+    threadgroup half * sm  = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV     + Q*T);  // scratch buffer for the results
 
-    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    o4x4_t lo[D16/NL];
+    // store the result for all queries in local memory (the O matrix from the paper)
+    o4_t lo[DV4/NL];
 
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
 
-    for (short i = tiisg; i < D4; i += NW) {
+    for (short i = tiisg; i < DK4; i += NW) {
         if (iq1 < args.ne01) {
             sq4[i] = (q4_t) q4[i];
         } else {
@@ -3648,8 +3671,8 @@ kernel void kernel_flash_attn_ext_vec(
     }
 
     // zero out lo
-    for (short i = 0; i < D16/NL; ++i) {
-        lo[i] = (o4x4_t) 0.0f;
+    for (short i = 0; i < DV4/NL; ++i) {
+        lo[i] = (o4_t) 0.0f;
     }
 
     // zero out shared memory SH
@@ -3674,14 +3697,6 @@ kernel void kernel_flash_attn_ext_vec(
         const short ikv2 = iq2/(args.ne02/args.ne_12_2);
         const short ikv3 = iq3/(args.ne03/args.ne_12_3);
 
-        // load the queries from shared memory into local memory
-        q4x4_t mq[D16/NL];
-
-        #pragma unroll(D16/NL)
-        for (short ii = 0; ii < D16; ii += NL) {
-            mq[ii/NL] = sq4x4[ii + tx];
-        }
-
         const bool has_mask = mask != q;
 
         // pointer to the mask
@@ -3713,43 +3728,56 @@ kernel void kernel_flash_attn_ext_vec(
 
             // Q*K^T
             {
-                // each simdgroup processes 1 query and 4 (NW/NL) keys
-                for (short cc = 0; cc < C/4; ++cc) {
-                    qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
+                // each simdgroup processes 1 query and NE (NW/NL) head elements
+                for (short cc = 0; cc < C/NE; ++cc) {
+                    qk_t mqk = 0.0f;
 
-                    device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                    device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
 
-                    #pragma unroll(D16/NL)
-                    for (short ii = 0; ii < D16; ii += NL) {
+                    #pragma unroll(DK4/NL)
+                    for (short ii = 0; ii < DK4; ii += NL) {
                         const short i = ii + tx;
 
-                        k4x4_t mk;
-                        deq_k(pk + i/nl_k, i%nl_k, mk);
+                        k4_t mk;
+                        deq_k_t4(pk + i/nl_k, i%nl_k, mk);
 
                         // note: this is less precise than the version below
-                        //mqka[0] += dot(mq[ii/NL][0], mk[0]);
-                        //mqka[1] += dot(mq[ii/NL][1], mk[1]);
-                        //mqka[2] += dot(mq[ii/NL][2], mk[2]);
-                        //mqka[3] += dot(mq[ii/NL][3], mk[3]);
-
-                        mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
-                        mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
-                        mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
-                        mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
+                        //mqka[0] += dot(mq[0], mk[0]);
+                        //mqka[1] += dot(mq[1], mk[1]);
+                        //mqka[2] += dot(mq[2], mk[2]);
+                        //mqka[3] += dot(mq[3], mk[3]);
+
+                        //q4x4_t mq = sq4x4[i];
+                        //mqka[0] += dot((float4) mq[0], (float4) mk[0]);
+                        //mqka[1] += dot((float4) mq[1], (float4) mk[1]);
+                        //mqka[2] += dot((float4) mq[2], (float4) mk[2]);
+                        //mqka[3] += dot((float4) mq[3], (float4) mk[3]);
+
+                        mqk += dot((float4) mk, (float4) sq4[i]);
                     }
 
-                    qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
+                    static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
 
-                    // simdgroup reduce
+                    // simdgroup reduce (NE = 4)
                     // [ 0 ..  7] -> [ 0]
                     // [ 8 .. 15] -> [ 8]
                     // [16 .. 23] -> [16]
                     // [24 .. 31] -> [24]
-                  //mqk += simd_shuffle_down(mqk, 16);
-                  //mqk += simd_shuffle_down(mqk,  8);
-                    mqk += simd_shuffle_down(mqk,  4);
-                    mqk += simd_shuffle_down(mqk,  2);
-                    mqk += simd_shuffle_down(mqk,  1);
+                    if (NE <= 1) {
+                        mqk += simd_shuffle_down(mqk, 16);
+                    }
+                    if (NE <= 2) {
+                        mqk += simd_shuffle_down(mqk,  8);
+                    }
+                    if (NE <= 4) {
+                        mqk += simd_shuffle_down(mqk,  4);
+                    }
+                    if (NE <= 8) {
+                        mqk += simd_shuffle_down(mqk,  2);
+                    }
+                    if (NE <= 16) {
+                        mqk += simd_shuffle_down(mqk,  1);
+                    }
 
                     // mqk = mqk*scale + mask*slope
                     if (tx == 0) {
@@ -3759,9 +3787,9 @@ kernel void kernel_flash_attn_ext_vec(
                             mqk = args.logit_softcap*precise::tanh(mqk);
                         }
 
-                        mqk += sm[4*cc + ty]*slope;
+                        mqk += sm[NE*cc + ty]*slope;
 
-                        ss[4*cc + ty] = mqk;
+                        ss[NE*cc + ty] = mqk;
                     }
                 }
             }
@@ -3784,8 +3812,8 @@ kernel void kernel_flash_attn_ext_vec(
                 ss[tiisg] = vs;
 
                 // O = diag(ms)*O
-                #pragma unroll(D16/NL)
-                for (short ii = 0; ii < D16; ii += NL) {
+                #pragma unroll(DV4/NL)
+                for (short ii = 0; ii < DV4; ii += NL) {
                     lo[ii/NL] *= ms;
                 }
             }
@@ -3794,17 +3822,18 @@ kernel void kernel_flash_attn_ext_vec(
 
             // O = O + (Q*K^T)*V
             {
-                for (short cc = 0; cc < C/4; ++cc) {
-                    device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+                //#pragma unroll(C/NE)
+                for (short cc = 0; cc < C/NE; ++cc) {
+                    device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
 
-                    const s4x4_t ms(ss[4*cc + ty]);
+                    const s4_t ms(ss[NE*cc + ty]);
 
-                    #pragma unroll(D16/NL)
-                    for (short ii = 0; ii < D16; ii += NL) {
+                    #pragma unroll(DV4/NL)
+                    for (short ii = 0; ii < DV4; ii += NL) {
                         const short i = ii + tx;
 
-                        v4x4_t mv;
-                        deq_v(pv4 + i/nl_v, i%nl_v, mv);
+                        v4_t mv;
+                        deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
 
                         lo[ii/NL] += mv*ms;
                     }
@@ -3819,7 +3848,7 @@ kernel void kernel_flash_attn_ext_vec(
         }
     }
 
-    // simdgroup reduce
+    // simdgroup reduce (NE = 4)
     // [ 0,  8, 16, 24] -> [ 0]
     // [ 1,  9, 17, 25] -> [ 1]
     // [ 2, 10, 18, 26] -> [ 2]
@@ -3828,37 +3857,48 @@ kernel void kernel_flash_attn_ext_vec(
     // [ 5, 13, 21, 29] -> [ 5]
     // [ 6, 14, 22, 30] -> [ 6]
     // [ 7, 15, 23, 31] -> [ 7]
-    for (short ii = 0; ii < D16; ii += NL) {
-        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
-        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
-
-        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
-        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
-
-        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
-        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
-
-        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
-        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
+    for (short ii = 0; ii < DV4; ii += NL) {
+        if (NE > 1) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
+        }
+
+        if (NE > 2) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
+        }
+
+        if (NE > 4) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
+        }
+
+        if (NE > 8) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
+        }
+
+        if (NE > 16) {
+            lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
+            lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
+            lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
+            lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
+        }
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // store results to shared memory
-    for (short i = tiisg; i < D16; i += NL) {
-        sr4x4[i] = lo[i/NL];
+    for (short i = tiisg; i < DV4; i += NL) {
+        sr4[i] = lo[i/NL];
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3885,22 +3925,22 @@ kernel void kernel_flash_attn_ext_vec(
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (short i = tiisg; i < D16; i += NW) {
-                sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
+            for (short i = tiisg; i < DV4; i += NW) {
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
             }
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    device float4x4 * dst44 = (device float4x4 *) dst;
+    device float4 * dst4 = (device float4 *) dst;
 
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
         const float S = ss[0];
 
-        for (short i = tiisg; i < D16; i += NW) {
-            dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
+        for (short i = tiisg; i < DV4; i += NW) {
+            dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
         }
     }
 }
@@ -3909,34 +3949,54 @@ kernel void kernel_flash_attn_ext_vec(
 //       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
 #define FA_TYPES \
-           half4,  half4x4, \
-                   half4x4, \
-                   half4x4, \
-    float,                  \
-    half,  half4,  half4x4, \
-                   half4x4
+           half4, \
+           half4, \
+           half4, \
+    float,        \
+    half,  half4, \
+           half4
 
-typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  128, 128, 4>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,           1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 128, 128, 4>;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0,        8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1,        8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0,        8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1,        8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0,        8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 128, 128, 4>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  192, 192, 4>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,           1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 192, 192, 4>;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0,        8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 192, 192, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1,        8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 192, 192, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0,        8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 192, 192, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1,        8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 192, 192, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0,        8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 192, 192, 4>;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  192, 128, 4>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,   1, dequantize_bf16, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 192, 128, 4>;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0,  2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1,  2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0,  2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1,  2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0,  2, dequantize_q8_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 192, 128, 4>;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  256>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  256, 256, 4>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,   1, dequantize_bf16, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,           1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 256, 256, 4>;
 #endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0,  2, dequantize_q4_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1,  2, dequantize_q4_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0,  2, dequantize_q5_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1,  2, dequantize_q5_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0,  2, dequantize_q8_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0,        8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1,        8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0,        8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1,        8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0,        8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 256, 256, 4>;
 
 #undef FA_TYPES
 

+ 4 - 0
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -8764,6 +8764,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 default:
                     return false;
                 }
+                if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                    // different head sizes of K and V are not supported yet
+                    return false;
+                }
                 if (op->src[0]->type != GGML_TYPE_F32) {
                     return false;
                 }

+ 1 - 1
ggml/src/ggml.c

@@ -4369,7 +4369,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     }
 
     // permute(0, 2, 1, 3)
-    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     float params[] = { scale, max_bias, logit_softcap };

+ 0 - 5
src/llama-context.cpp

@@ -2316,11 +2316,6 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;

+ 38 - 31
tests/test-backend-ops.cpp

@@ -3217,7 +3217,8 @@ struct test_leaky_relu : public test_case {
 
 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
-    const int64_t hs; // head size
+    const int64_t hsk; // K head size
+    const int64_t hsv; // V head size
     const int64_t nh; // num heads
     const int64_t nr; // repeat in Q, tests for grouped-query attention
     const int64_t kv; // kv size
@@ -3233,7 +3234,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+        return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3243,17 +3244,18 @@ struct test_flash_attn_ext : public test_case {
     uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
         // Just counting matmul costs:
-        // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
-        return 2 * 2 * nh*nr * nb * hs * kv;
+        // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
+        return 2 * nh*nr * nb * (hsk + hsv) * kv;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
+    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
                         bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
                         ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+        : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
+        const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
+        const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
 
         auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
             int64_t ne[4] = {ne0, ne1, ne2, ne3};
@@ -3268,13 +3270,13 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = create_permuted(type_KV,       hs_padded, kv, nh,    1);
+        ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,    1);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV,       hs_padded, kv, nh,    1);
+        ggml_tensor * v = create_permuted(type_KV,       hsv_padded, kv, nh,    1);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -3283,7 +3285,7 @@ struct test_flash_attn_ext : public test_case {
             ggml_set_name(m, "m");
         }
 
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
         ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");
 
@@ -4412,27 +4414,32 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
-    for (int hs : { 64, 80, 128, 256, }) {
-        for (bool mask : { true, false } ) {
-            for (float max_bias : { 0.0f, 8.0f }) {
-                if (!mask && max_bias > 0.0f) continue;
-                for (float logit_softcap : {0.0f, 10.0f}) {
-                    if (hs != 128 && logit_softcap != 0.0f) continue;
-                    for (int nh : { 4, }) {
-                        for (int nr : { 1, 4, 16 }) {
-                            if (nr == 16 && hs != 128) continue;
-                            for (int kv : { 512, 1024, }) {
-                                if (nr != 1 && kv != 512) continue;
-                                for (int nb : { 1, 3, 32, 35, }) {
-                                    for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
-                                        if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
-                                        for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                            test_cases.emplace_back(new test_flash_attn_ext(
-                                                hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
-                                            // run fewer test cases permuted
-                                            if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+    for (int hsk : { 64, 80, 128, 192, 256, }) {
+        for (int hsv : { 64, 80, 128, 192, 256, }) {
+            if (hsk != 192 && hsk != hsv) continue;
+            if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+
+            for (bool mask : { true, false } ) {
+                for (float max_bias : { 0.0f, 8.0f }) {
+                    if (!mask && max_bias > 0.0f) continue;
+                    for (float logit_softcap : {0.0f, 10.0f}) {
+                        if (hsk != 128 && logit_softcap != 0.0f) continue;
+                        for (int nh : { 4, }) {
+                            for (int nr : { 1, 4, 16 }) {
+                                if (nr == 16 && hsk != 128) continue;
+                                for (int kv : { 512, 1024, }) {
+                                    if (nr != 1 && kv != 512) continue;
+                                    for (int nb : { 1, 3, 32, 35, }) {
+                                        for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                            if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                                 test_cases.emplace_back(new test_flash_attn_ext(
-                                                    hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                    hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                                // run fewer test cases permuted
+                                                if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                    test_cases.emplace_back(new test_flash_attn_ext(
+                                                        hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                }
                                             }
                                         }
                                     }

Some files were not shown because too many files changed in this diff