Browse Source

ggml-cpu: Use tiled FA for prompt-processing (#19012)

* ggml-cpu: Use tiled FA for prompt-processing

the FA performance is gimped on CPU on long contexts because it essentially uses a vector kernel. This PR adds a tiled FA for PP. Perf tuning for tile sizes done on a AMD EPYC single-socket 64-c machine.

* fix out of bounds for mask

* skip rows where there are all masks

* skip tile if mask is inf

* store mask in worksize

* check inf tile earlier
Aman Gupta 2 days ago
parent
commit
bcb43163ae
3 changed files with 303 additions and 4 deletions
  1. 8 0
      ggml/src/ggml-cpu/common.h
  2. 6 3
      ggml/src/ggml-cpu/ggml-cpu.c
  3. 289 1
      ggml/src/ggml-cpu/ops.cpp

+ 8 - 0
ggml/src/ggml-cpu/common.h

@@ -6,6 +6,9 @@
 #include "ggml-impl.h"
 #include "simd-mappings.h"
 
+#define GGML_FA_TILE_Q  32
+#define GGML_FA_TILE_KV 16
+
 #ifdef __cplusplus
 
 #include <utility>
@@ -84,4 +87,9 @@ static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_pa
     return {ir0, ir1};
 }
 
+struct ggml_fa_tile_config {
+    static constexpr size_t Q  = GGML_FA_TILE_Q;
+    static constexpr size_t KV = GGML_FA_TILE_KV;
+};
+
 #endif

+ 6 - 3
ggml/src/ggml-cpu/ggml-cpu.c

@@ -14,6 +14,7 @@
 #include "vec.h"
 #include "ops.h"
 #include "ggml.h"
+#include "common.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -2866,10 +2867,12 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_FLASH_ATTN_EXT:
                     {
-                        const int64_t ne10 = node->src[1]->ne[0]; // DK
-                        const int64_t ne20 = node->src[2]->ne[0]; // DV
+                        const int64_t DK = node->src[1]->ne[0];
+                        const int64_t DV = node->src[2]->ne[0];
 
-                        cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
+                        // Tiled flash attention scratch (tile sizes defined in common.h)
+                        // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
+                        cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
                     } break;
                 case GGML_OP_FLASH_ATTN_BACK:
                     {

+ 289 - 1
ggml/src/ggml-cpu/ops.cpp

@@ -8164,6 +8164,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
+
         for (int64_t ic = 0; ic < nek1; ++ic) {
             const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
@@ -8271,6 +8272,280 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
     }
 }
 
+static void ggml_compute_forward_flash_attn_ext_tiled(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        int ir0, int ir1) {
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
+
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(k->type == v->type);
+    const ggml_type kv_type = k->type;
+
+    const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
+    const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
+    const ggml_vec_dot_t    kv_vec_dot    = kv_type_traits_cpu->vec_dot;
+    const size_t kv_type_size = ggml_type_size(kv_type);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head      = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    int ith = params->ith;
+
+    static constexpr int Q_TILE_SZ  = ggml_fa_tile_config::Q;
+    static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
+
+    GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
+
+    int ir = ir0;
+    while (ir < ir1) {
+        // q indices for the start of this tile
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        // Number of valid rows in this tile:
+        // - limited by tile size (Q_TILE_SZ)
+        // - limited by chunk boundary (ir1 - ir)
+        // - limited by head boundary (neq1 - iq1) to avoid crossing into next head
+        const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
+        GGML_ASSERT(tile_rows > 0);
+
+        const uint32_t h = iq2; // head index
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+        float S[Q_TILE_SZ];
+        float M[Q_TILE_SZ];
+
+        for (int i = 0 ; i < Q_TILE_SZ; ++i) {
+            S[i] = 0.;
+            M[i] = -INFINITY;
+        }
+
+        // Per-thread scratch layout:
+        // Q_q:    Q_TILE_SZ * DK (converted Q tile in KV type)
+        // KQ:     Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
+        // mask:   Q_TILE_SZ * KV_TILE_SZ (mask in float)
+        // VKQ32:  Q_TILE_SZ * DV (FP32 output accumulator)
+        // V32:    KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
+        float * base  = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
+
+        void  * Q_q    = base;
+        float * KQ     = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
+        float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
+        float * VKQ32  = mask32 + Q_TILE_SZ * KV_TILE_SZ;
+        float * V32    = VKQ32 + Q_TILE_SZ * DV;  // F32 buffer for V tile
+
+        memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
+        memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        for (int tq = 0; tq < tile_rows; tq++) {
+            const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
+            kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
+        }
+        // Zero-pad remaining rows
+        for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
+            memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
+        }
+
+        for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
+
+            // skip the tile entirely if all the masks are -inf
+            if (mask) {
+                bool can_skip = true;
+                for (int tq = 0; tq < tile_rows; tq++) {
+                    const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
+                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                        mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
+                        if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
+                            can_skip = false;
+                        }
+                    }
+                }
+
+                if (can_skip) {
+                    continue;
+                }
+            }
+
+            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
+                for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                    const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
+                    float s;
+                    kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
+                    KQ[tq * KV_TILE_SZ + tk] = s * scale;
+                }
+            }
+
+            if (logit_softcap != 0.0f) {
+                ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
+                ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
+            }
+
+            if (mask) {
+                ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
+            }
+
+            bool skip[Q_TILE_SZ] = {};
+
+            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                float * kq_row = KQ + tq * KV_TILE_SZ;
+
+                float tile_max;
+                ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
+
+                if (tile_max == -INFINITY) {
+                    skip[tq] = true;
+                    continue;
+                }
+
+                const float Mold = M[tq];
+                const float Mnew = fmaxf(Mold, tile_max);
+
+                if (Mnew > Mold) {
+                    const float ms = expf(Mold - Mnew);
+                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
+                    S[tq] *= ms;
+                }
+                M[tq] = Mnew;
+
+
+                S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
+            }
+
+            // Convert V tile to F32 first (if F16), then do MAD
+            // On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
+            // TODO: on ARM, native f16 should be faster
+            if (kv_type == GGML_TYPE_F16) {
+                for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                    const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
+                    ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
+                }
+                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                    if (skip[tq]) continue;
+                    float * vkq_row = VKQ32 + tq * DV;
+                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                        const float p = KQ[tq * KV_TILE_SZ + tk];
+                        ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
+                    }
+                }
+            } else {
+                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                    if (skip[tq]) continue;
+                    float * vkq_row = VKQ32 + tq * DV;
+                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                        const float p = KQ[tq * KV_TILE_SZ + tk];
+                        const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
+                        ggml_vec_mad_f32(DV, vkq_row, v_row, p);
+                    }
+                }
+            }
+        }
+
+        // sinks (apply only to valid rows in the tile)
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            for (int tq = 0; tq < tile_rows; tq++) {
+                float ms = 1.0f;
+                float vs = 1.0f;
+
+                if (s > M[tq]) {
+                    ms = expf(M[tq] - s);
+                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
+                } else {
+                    vs = expf(s - M[tq]);
+                }
+
+                S[tq] = S[tq] * ms + vs;
+            }
+        }
+
+        for (int tq = 0; tq < tile_rows; tq++) {
+            // V /= S
+            const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
+            ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
+
+            // dst indices
+            const int i1 = iq1 + tq;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
+        }
+
+        ir += tile_rows;
+    }
+}
+
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -8343,6 +8618,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     // The number of elements in each chunk
     const int64_t dr = (nr + nchunk - 1) / nchunk;
 
+    static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
+    static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;
+    const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
+    const bool use_tiled = (q->type == GGML_TYPE_F32 &&
+                            kv_is_f32_or_f16 &&
+                            k->type == v->type &&
+                            nek1 % KV_TILE_SZ == 0 &&
+                            neq1 >= Q_TILE_SZ);  // Only use tiled for batch >= tile size
+
     // The first chunk comes from our thread_id, the rest will get auto-assigned.
     int current_chunk = ith;
 
@@ -8350,7 +8634,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int64_t ir0 = dr * current_chunk;
         const int64_t ir1 = MIN(ir0 + dr, nr);
 
-        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+        if (use_tiled) {
+            ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
+        } else {
+            ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+        }
 
         current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
     }