@@ -2,9 +2,9 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
+#include <assert.h>
 #include <HAP_farf.h>
 #include <HAP_perf.h>
-
 #include <math.h>
 #include <string.h>
 
@@ -111,7 +111,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
     hvx_vec_store_u(r, 4, rsum);
 }
 
-// MAD: y (F32) += x (F16) * v (float)
+// MAD: y (F32) += x (F16) * s (float)
 static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
     const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
     HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
@@ -318,9 +318,13 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
     uint32_t ic = 0;
 
     // Process in blocks of 32 (VLEN_FP32)
-    for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
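+    // Stage the block's score vectors and track a running max across them;
+    // the softmax update and V accumulation run as a second pass below.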
+    static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
+    HVX_Vector_x4 scores_x4;
+    HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
+    for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
         // 1. Compute scores
-        float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
+        float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
         for (int j = 0; j < VLEN_FP32; ++j) {
             const uint32_t cur_ic = ic + j;
             const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
@@ -356,36 +360,48 @@
             scores = Q6_Vsf_equals_Vqf32(scores);
         }
 
+        scores_x4.v[iv] = scores;
+        v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
+    }
+
+    {
         // 4. Online Softmax Update
-        HVX_Vector v_max = hvx_vec_reduce_max_f32(scores);
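+        // v_max holds elementwise maxima across the staged score vectors;
+        // one lane reduction yields the block max.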
+        v_max = hvx_vec_reduce_max_f32(v_max);
         float m_block = hvx_vec_get_f32(v_max);
-
         float M_old = M;
         float M_new = (m_block > M) ? m_block : M;
         M = M_new;
 
-        float ms = expf(M_old - M_new);
-
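+        // ms rescales the existing accumulator (and running sum S) when the block raises the max.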
+        const float ms = expf(M_old - M_new);
         hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
-        S = S * ms;
 
         HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
-        HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
-        HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
-
-        HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P);
-        float p_sum = hvx_vec_get_f32(p_sum_vec);
-        S += p_sum;
-
-        // 5. Accumulate V
-        float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
-        *(HVX_Vector*)p_arr = P;
-
-        for (int j = 0; j < VLEN_FP32; ++j) {
-            const uint32_t cur_ic = ic + j;
-            const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
-            hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
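+        // Exponentiate each staged score vector and accumulate V; exp sums
+        // stay vector-wise until a single reduction after the loop.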
+        HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
+        for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
+            HVX_Vector scores = scores_x4.v[iv];
+            HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
+            HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
+
+            p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
+
+            // 5. Accumulate V
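+            // Spill P to scalars: lane j scales V row ic2 + j.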
+            float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
+            *(HVX_Vector*)p_arr = P;
+
+            for (int j = 0; j < VLEN_FP32; ++j) {
+                const uint32_t cur_ic = ic2 + j;
+                const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
+                hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
+            }
         }
+
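+        // One lane reduction per block, then fold the block's exp sum into S.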
+        p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
+        S = S * ms + hvx_vec_get_f32(p_sum_vec);
     }
 
     // Leftover