@@ -8164,6 +8164,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
+
         for (int64_t ic = 0; ic < nek1; ++ic) {
             const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
@@ -8271,6 +8272,280 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
     }
 }
 
+static void ggml_compute_forward_flash_attn_ext_tiled(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        int ir0, int ir1) {
+    const ggml_tensor * q = dst->src[0];
+    const ggml_tensor * k = dst->src[1];
+    const ggml_tensor * v = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(k->type == v->type);
+    const ggml_type kv_type = k->type;
+
+    const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
+    const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
+    const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
+    const size_t kv_type_size = ggml_type_size(kv_type);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
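+
+    // Grouped-query attention: Q can have more heads than K/V. rk2 is how many consecutive
+    // Q heads share one K/V head (e.g. 32 Q heads over 8 KV heads gives rk2 = 4), and the
+    // ik2 = iq2/rk2 and iv2 = iq2/rv2 divisions below map a Q head to its K/V head.
+    // rk3/rv3 do the same along the batch dimension.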
+
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
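+
+    // ALiBi (max_bias > 0): head h gets slope m0^(h+1) for h < n_head_log2 and
+    // m1^(2*(h - n_head_log2) + 1) otherwise (see the per-tile slope computation below).
+    // With max_bias == 0 the slope is 1 and the mask is applied unscaled.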
+
+    int ith = params->ith;
+
+    static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
+    static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
+
+    GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
+
+    int ir = ir0;
+    while (ir < ir1) {
+        // q indices for the start of this tile
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        // Number of valid rows in this tile:
+        // - limited by the tile size (Q_TILE_SZ)
+        // - limited by the chunk boundary (ir1 - ir)
+        // - limited by the head boundary (neq1 - iq1) to avoid crossing into the next head
+        const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
+        GGML_ASSERT(tile_rows > 0);
+
+        const uint32_t h = iq2; // head index
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+        float S[Q_TILE_SZ];
+        float M[Q_TILE_SZ];
+
+        for (int i = 0; i < Q_TILE_SZ; ++i) {
+            S[i] = 0.0f;
+            M[i] = -INFINITY;
+        }
+
+        // Per-thread scratch layout:
+        //   Q_q:    Q_TILE_SZ  * DK          (converted Q tile in the KV type)
+        //   KQ:     Q_TILE_SZ  * KV_TILE_SZ  (attention scores in float)
+        //   mask32: Q_TILE_SZ  * KV_TILE_SZ  (mask in float)
+        //   VKQ32:  Q_TILE_SZ  * DV          (FP32 output accumulator)
+        //   V32:    KV_TILE_SZ * DV          (F32 buffer for the V tile - used for the F16 conversion)
+        float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
+
+        void  * Q_q    = base;
+        float * KQ     = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
+        float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
+        float * VKQ32  = mask32 + Q_TILE_SZ * KV_TILE_SZ;
+        float * V32    = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for the V tile
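+        // The Q_q region is sized as Q_TILE_SZ*DK floats, an upper bound for the converted tile:
+        // this path only runs for F32/F16 K/V, so kv_type_size <= sizeof(float).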
+
+        memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
+        memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        for (int tq = 0; tq < tile_rows; tq++) {
+            const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
+            kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
+        }
+        // Zero-pad the remaining rows
+        for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
+            memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
+        }
+
+        for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
+
+            // skip the tile entirely if all the masks are -inf
+            if (mask) {
+                bool can_skip = true;
+                for (int tq = 0; tq < tile_rows; tq++) {
+                    const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
+                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                        mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
+                        if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
+                            can_skip = false;
+                        }
+                    }
+                }
+
+                if (can_skip) {
+                    continue;
+                }
+            }
+
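+            // Compute the KQ tile with the KV type's vec_dot. Rows tq >= tile_rows use the zero-padded
+            // Q rows, so the loops keep a fixed Q_TILE_SZ x KV_TILE_SZ shape; those results are never written out.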
+            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
+                for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                    const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
+                    float s;
+                    kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
+                    KQ[tq * KV_TILE_SZ + tk] = s * scale;
+                }
+            }
+
+            if (logit_softcap != 0.0f) {
+                ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
+                ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
+            }
+
+            if (mask) {
+                ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
+            }
+
+            bool skip[Q_TILE_SZ] = {};
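+
+            // Online softmax update (ref: https://arxiv.org/pdf/2112.05682.pdf): M[tq] is the running
+            // maximum of the logits seen so far for row tq, and S[tq] the running sum of exp(logit - M[tq]).
+            // When a new KV tile raises the maximum from Mold to Mnew, the partial numerator (the VKQ32 row)
+            // and the denominator S[tq] are rescaled by exp(Mold - Mnew) so every term stays relative to the
+            // current maximum.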
+
+            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                float * kq_row = KQ + tq * KV_TILE_SZ;
+
+                float tile_max;
+                ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
+
+                if (tile_max == -INFINITY) {
+                    skip[tq] = true;
+                    continue;
+                }
+
+                const float Mold = M[tq];
+                const float Mnew = fmaxf(Mold, tile_max);
+
+                if (Mnew > Mold) {
+                    const float ms = expf(Mold - Mnew);
+                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
+                    S[tq] *= ms;
+                }
+                M[tq] = Mnew;
+
+                S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
+            }
+
+            // Convert the V tile to F32 first (if F16), then do the MAD.
+            // On x86, ggml_vec_mad_f16 internally converts F16<->F32 on every load/store, so pre-converting is faster.
+            // TODO: on ARM, native F16 should be faster
+            if (kv_type == GGML_TYPE_F16) {
+                for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                    const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
+                    ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
+                }
+                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                    if (skip[tq]) continue;
+                    float * vkq_row = VKQ32 + tq * DV;
+                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                        const float p = KQ[tq * KV_TILE_SZ + tk];
+                        ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
+                    }
+                }
+            } else {
+                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                    if (skip[tq]) continue;
+                    float * vkq_row = VKQ32 + tq * DV;
+                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                        const float p = KQ[tq * KV_TILE_SZ + tk];
+                        const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
+                        ggml_vec_mad_f32(DV, vkq_row, v_row, p);
+                    }
+                }
+            }
+        }
+
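+        // The attention sink is one extra per-head logit that enters the softmax denominator but
+        // contributes no value: it rescales S[tq] (and the accumulator, if it becomes the new maximum)
+        // exactly like another score would, without a matching V row.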
+        // sinks (apply only to valid rows in the tile)
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            for (int tq = 0; tq < tile_rows; tq++) {
+                float ms = 1.0f;
+                float vs = 1.0f;
+
+                if (s > M[tq]) {
+                    ms = expf(M[tq] - s);
+                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
+                } else {
+                    vs = expf(s - M[tq]);
+                }
+
+                S[tq] = S[tq]*ms + vs;
+            }
+        }
+
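+        // Each output row of DV floats goes to flat row (i3*ne1*ne2 + i1*ne1 + i2) of dst,
+        // i.e. heads are interleaved per position; this matches the permute(0, 2, 1, 3) layout
+        // that the existing per-row kernel writes.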
+        for (int tq = 0; tq < tile_rows; tq++) {
+            // V /= S
+            const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
+            ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
+
+            // dst indices
+            const int i1 = iq1 + tq;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
+        }
+
+        ir += tile_rows;
+    }
+}
+
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -8343,6 +8618,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     // The number of elements in each chunk
     const int64_t dr = (nr + nchunk - 1) / nchunk;
 
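+    // Use the tiled kernel only when its preconditions hold: F32 Q, matching F32/F16 K and V,
+    // a KV length that is a multiple of the KV tile, and at least one full Q tile per head.
+    // Otherwise fall back to the existing per-row kernel.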
+    static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
+    static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;
+    const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
+    const bool use_tiled = (q->type == GGML_TYPE_F32 &&
+                            kv_is_f32_or_f16 &&
+                            k->type == v->type &&
+                            nek1 % KV_TILE_SZ == 0 &&
+                            neq1 >= Q_TILE_SZ); // only use the tiled path when the batch is at least one Q tile
+
     // The first chunk comes from our thread_id, the rest will get auto-assigned.
     int current_chunk = ith;
 
@@ -8350,7 +8634,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int64_t ir0 = dr * current_chunk;
        const int64_t ir1 = MIN(ir0 + dr, nr);
 
-        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+        if (use_tiled) {
+            ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
+        } else {
+            ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+        }
 
         current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
     }