Browse Source

ggml-hexagon: flash-attn opt (#19025)

* optimize flash attention kernel by improving score computation and online softmax update

* wip

* Refactor online softmax update in flash attention kernel for improved performance

* Optimize flash attention kernel by replacing float array with HVX_Vector for score computation

* wip
nullname 6 days ago
parent
commit
8af1f5f430
1 changed files with 34 additions and 24 deletions
  1. 34 24
      ggml/src/ggml-hexagon/htp/flash-attn-ops.c

+ 34 - 24
ggml/src/ggml-hexagon/htp/flash-attn-ops.c

@@ -2,9 +2,9 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
+#include <assert.h>
 #include <HAP_farf.h>
 #include <HAP_perf.h>
-
 #include <math.h>
 #include <string.h>
 
@@ -111,7 +111,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
     hvx_vec_store_u(r, 4, rsum);
 }
 
-// MAD: y (F32) += x (F16) * v (float)
+// MAD: y (F32) += x (F16) * s (float)
 static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
     const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
     HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
@@ -318,9 +318,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
             uint32_t ic = 0;
 
             // Process in blocks of 32 (VLEN_FP32)
-            for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
+            static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
+            HVX_Vector_x4 scores_x4;
+            HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
+            for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
                 // 1. Compute scores
-                float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
+                float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
                 for (int j = 0; j < VLEN_FP32; ++j) {
                     const uint32_t cur_ic = ic + j;
                     const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
@@ -356,36 +359,43 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
                     scores = Q6_Vsf_equals_Vqf32(scores);
                 }
 
+                scores_x4.v[iv] = scores;
+                v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
+            }
+
+            {
                 // 4. Online Softmax Update
-                HVX_Vector v_max = hvx_vec_reduce_max_f32(scores);
+                v_max = hvx_vec_reduce_max_f32(v_max);
                 float m_block = hvx_vec_get_f32(v_max);
-
                 float M_old = M;
                 float M_new = (m_block > M) ? m_block : M;
                 M = M_new;
 
-                float ms = expf(M_old - M_new);
-
+                const float ms = expf(M_old - M_new);
                 hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
-                S = S * ms;
 
                 HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
-                HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
-                HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
-
-                HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P);
-                float p_sum = hvx_vec_get_f32(p_sum_vec);
-                S += p_sum;
-
-                // 5. Accumulate V
-                float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
-                *(HVX_Vector*)p_arr = P;
-
-                for (int j = 0; j < VLEN_FP32; ++j) {
-                    const uint32_t cur_ic = ic + j;
-                    const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
-                    hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
+                HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
+                for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
+                    HVX_Vector scores = scores_x4.v[iv];
+                    HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
+                    HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
+
+                    p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
+
+                    // 5. Accumulate V
+                    float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
+                    *(HVX_Vector*)p_arr = P;
+
+                    for (int j = 0; j < VLEN_FP32; ++j) {
+                        const uint32_t cur_ic = ic2 + j;
+                        const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
+                        hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
+                    }
                 }
+
+                p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
+                S = S * ms + hvx_vec_get_f32(p_sum_vec);
             }
 
             // Leftover