Jelajahi Sumber

ggml-hexagon: fix swiglu failure at `test-backend-ops` (#17344)

* refactor: use hvx_vec_exp_fp32_guard_inf for overflow handling in hvx_exp_f32

* feat: add fast sigmoid function with overflow guard for fp32

* refactor: replace hvx_vec_inverse_fp32 with hvx_vec_inverse_fp32_guard_inf for improved overflow handling

* feat: enhance hvx_add_scalar_f32 with overflow handling using infinity guard

* wip

* add HVX_Vector_Alias

wip

* wip

* fix: improve handling of src1 tensor in glu_swiglu_fp32_per_thread function

* fix nc

* wip

* wip

* handle nan at inverse

* wip

* fix neg

* wip

* rename

* fix hvx_vec_inverse_fp32_guard_inf to handle infinity and NaN cases correctly

* wip

* fix hvx_vec_inverse_fp32_guard_inf to handle NaN cases correctly

* wip

* wip

* wip

* fix output sign
nullname 1 bulan lalu
induk
melakukan
21d31e0810

+ 12 - 18
ggml/src/ggml-hexagon/htp/act-ops.c

@@ -106,33 +106,32 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
     t1 = HAP_perf_get_qtimer_count();
 
     int is_aligned = 1;
-    int opt_path   = 0;
     if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
         is_aligned = 0;
         FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
     }
-    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
-        opt_path = 1;
-    }
 
     const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
     const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
     uint8_t * restrict data_dst        = (uint8_t *) dst->data;
 
-    bool src1_valid = src1->ne[0];
+    const bool src1_valid = src1->ne[0];
+    const int  nc         = (src1_valid) ? ne00 : ne00 / 2;
     if (!src1_valid) {
-        data_src1     = data_src0;
-        src1_row_size = src0_row_size;
+        const int32_t swapped = op_params[1];
+        data_src1             = data_src0;
+        src1_row_size         = src0_row_size;
+
+        const size_t nc_in_bytes = nc * SIZEOF_FP32;
+        data_src0 += swapped ? nc_in_bytes : 0;
+        data_src1 += swapped ? 0 : nc_in_bytes;
     }
 
     uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
     uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
     uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_row_size);
 
-    const int32_t swapped = op_params[1];
-
-    const int nc = (src1_valid) ? ne0 : ne0 / 2;
-
+    const bool opt_path = ((1 == is_aligned) && !(nb01 & (VLEN - 1)));
     for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
         const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
         const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
@@ -142,12 +141,7 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
             htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
         }
 
-        if (!src1_valid) {
-            src0 += swapped ? nc : 0;
-            src1 += swapped ? 0 : nc;
-        }
-
-        if (1 == opt_path) {
+        if (opt_path) {
             hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
             hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
                                 (uint8_t *) dst, nc);
@@ -218,7 +212,7 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
     const float   alpha   = ((const float *) (op_params))[2];
     const float   limit   = ((const float *) (op_params))[3];
 
-    const int nc = (src1_valid) ? ne0 : ne0 / 2;
+    const int nc = (src1_valid) ? ne00 : ne00 / 2;
 
     for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
         const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));

+ 19 - 6
ggml/src/ggml-hexagon/htp/hvx-exp.c

@@ -16,6 +16,19 @@
 #include "hvx-utils.h"
 #include "ops-utils.h"
 
+static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec) {
+    static const float kInf    = INFINITY;
+    static const float kMaxExp = 88.02f;  // log(INF)
+
+    const HVX_Vector     max_exp = hvx_vec_splat_fp32(kMaxExp);
+    const HVX_Vector     inf     = hvx_vec_splat_fp32(kInf);
+    const HVX_VectorPred pred0   = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
+
+    HVX_Vector out = hvx_vec_exp_fp32(in_vec);
+
+    return Q6_V_vmux_QVV(pred0, inf, out);
+}
+
 void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
     int left_over       = num_elems & (VLEN_FP32 - 1);
     int num_elems_whole = num_elems - left_over;
@@ -42,9 +55,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
         for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
             if (true == negate) {
                 HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
-                *p_vec_out++          = hvx_vec_exp_fp32(neg_vec_in);
+                *p_vec_out++          = hvx_vec_exp_fp32_guard(neg_vec_in);
             } else {
-                *p_vec_out++ = hvx_vec_exp_fp32(*p_vec_in1++);
+                *p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++);
             }
         }
     } else {
@@ -54,9 +67,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
 
             if (true == negate) {
                 HVX_Vector neg_vec_in                    = hvx_vec_neg_fp32(in);
-                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(neg_vec_in);
+                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in);
             } else {
-                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(in);
+                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in);
             }
         }
     }
@@ -70,9 +83,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
         if (true == negate) {
             HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
 
-            vec_out = hvx_vec_exp_fp32(neg_vec_in);
+            vec_out = hvx_vec_exp_fp32_guard(neg_vec_in);
         } else {
-            vec_out = hvx_vec_exp_fp32(in);
+            vec_out = hvx_vec_exp_fp32_guard(in);
         }
 
         hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);

+ 3 - 3
ggml/src/ggml-hexagon/htp/hvx-inverse.c

@@ -38,13 +38,13 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const
 
         #pragma unroll(4)
         for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            *p_vec_out++ = hvx_vec_inverse_fp32(*p_vec_in++);
+            *p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++);
         }
     } else {
         #pragma unroll(4)
         for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
             HVX_Vector in                            = *(HVX_UVector *) (src + i * SIZEOF_FP32);
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32(in);
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in);
         }
     }
 
@@ -53,7 +53,7 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const
         float *       dstf = (float *) dst + num_elems_whole;
 
         HVX_Vector in  = *(HVX_UVector *) srcf;
-        HVX_Vector out = hvx_vec_inverse_fp32(in);
+        HVX_Vector out = hvx_vec_inverse_fp32_guard(in);
 
         hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
     }

+ 20 - 7
ggml/src/ggml-hexagon/htp/hvx-utils.c

@@ -401,7 +401,9 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
         FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
     }
 
-    HVX_Vector val_vec = hvx_vec_splat_fp32(val);
+    static const float kInf    = INFINITY;
+    const HVX_Vector   inf     = hvx_vec_splat_fp32(kInf);
+    HVX_Vector         val_vec = hvx_vec_splat_fp32(val);
 
     if (0 == unaligned_loop) {
         HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
@@ -409,17 +411,24 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
 
         #pragma unroll(4)
         for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, val_vec);
-            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
+            HVX_Vector           in       = *vec_in1++;
+            const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
+            HVX_Vector           v        = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
+            v                             = Q6_Vsf_equals_Vqf32(v);
+            v                             = Q6_V_vmux_QVV(pred_inf, inf, v);
+            *vec_out++                    = v;
         }
     } else {
         #pragma unroll(4)
         for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
             HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
 
-            HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
+            const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
+            HVX_Vector           out      = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
+            out                           = Q6_Vsf_equals_Vqf32(out);
+            out                           = Q6_V_vmux_QVV(pred_inf, inf, out);
 
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = out;
         }
     }
 
@@ -429,8 +438,12 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
 
         HVX_Vector in = *(HVX_UVector *) srcf;
 
-        HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
+        const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
+        HVX_Vector           out      = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
+        out                           = Q6_Vsf_equals_Vqf32(out);
+        out                           = Q6_V_vmux_QVV(pred_inf, inf, out);
+
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
     }
 }
 

+ 45 - 11
ggml/src/ggml-hexagon/htp/hvx-utils.h

@@ -12,6 +12,15 @@
 #define VLEN_FP32   (VLEN / SIZEOF_FP32)
 #define VLEN_FP16   (VLEN / SIZEOF_FP16)
 
+typedef union {
+    HVX_Vector v;
+    uint8_t    b[VLEN];
+    uint16_t   h[VLEN_FP16];
+    uint32_t   w[VLEN_FP32];
+    __fp16     fp16[VLEN_FP16];
+    float      fp32[VLEN_FP32];
+} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias;
+
 static inline HVX_Vector hvx_vec_splat_fp32(float i) {
     union {
         float   f;
@@ -243,19 +252,16 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3
 }
 
 static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
-    union {
-        HVX_Vector v;
-        __fp16 d[64];
-    } u = { .v = v };
+    HVX_VectorAlias u = { .v = v };
 
     const uint32_t n0 = n / 16;
     const uint32_t n1 = n % 16;
     int            i  = 0;
     for (; i < n0; i++) {
-        htp_dump_fp16_line(pref, u.d + (16 * i), 16);
+        htp_dump_fp16_line(pref, u.fp16 + (16 * i), 16);
     }
     if (n1) {
-        htp_dump_fp16_line(pref, u.d + (16 * i), n1);
+        htp_dump_fp16_line(pref, u.fp16 + (16 * i), n1);
     }
 }
 
@@ -411,8 +417,8 @@ static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n
 
     HVX_Vector sum = in, sum_t;
     while (width < total) {
-        sum_t = Q6_V_vror_VR(sum, width);       // rotate right
-        sum   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum
+        sum_t = Q6_V_vror_VR(sum, width);                               // rotate right
+        sum   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t));  // elementwise sum
         width = width << 1;
     }
     return sum;
@@ -491,7 +497,7 @@ static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) {
 static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) {
     // neg by setting the fp16 sign bit
     HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
-    return Q6_V_vor_VV(v, mask);
+    return Q6_V_vxor_VV(v, mask);
 }
 
 static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
@@ -506,7 +512,7 @@ static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
 #else
     // neg by setting the fp32 sign bit
     HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
-    return Q6_V_vor_VV(v, mask);
+    return Q6_V_vxor_VV(v, mask);
 #endif  // __HTP_ARCH__ > 75
 }
 
@@ -720,6 +726,24 @@ static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) {
     return Q6_Vsf_equals_Vqf32(r_qf);
 }
 
+static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf) {
+    static const float    kInf     = INFINITY;
+    static const uint32_t kNanMask = 0x7fffffff;
+    static const uint32_t kNanMin  = 0x7f800000;
+
+    const HVX_Vector     inf      = hvx_vec_splat_fp32(kInf);
+    const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(inf, v_sf);
+
+    HVX_Vector out = hvx_vec_inverse_fp32(v_sf);
+
+    const HVX_Vector     nan_mask   = Q6_V_vsplat_R(kNanMask);
+    const HVX_Vector     nan_min    = Q6_V_vsplat_R(kNanMin);
+    HVX_Vector           masked_out = Q6_V_vand_VV(out, nan_mask);
+    const HVX_VectorPred pred       = Q6_Q_vcmp_gtand_QVuwVuw(pred_inf, nan_min, masked_out);
+
+    return Q6_V_vmux_QVV(pred, out, Q6_V_vzero());
+}
+
 #define FAST_SIGMOID_LOG2F (0x3fb8aa3b)  // 1.442695022
 #define FAST_SIGMOID_C1    (0x3d009076)  // 0.03138777
 #define FAST_SIGMOID_C2    (0x3e8d74bd)  // 0.276281267
@@ -934,6 +958,16 @@ static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) {
     return Q6_Vsf_equals_Vqf32(temp);
 }
 
+static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v) {
+    static const float kMaxExp = -88.02f;  // log(INF)
+
+    const HVX_Vector     max_exp  = Q6_V_vsplat_R(*((uint32_t *) &kMaxExp));
+    const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(v, max_exp);
+
+    HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v);
+    return Q6_V_vmux_QVV(pred_inf, out, Q6_V_vzero());
+}
+
 static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
     int step_of_1 = num_elems >> 5;
     int remaining = num_elems - step_of_1 * VLEN_FP32;
@@ -945,7 +979,7 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t *
 
     #pragma unroll(4)
     for (int i = 0; i < step_of_1; i++) {
-        v_dst[i] = hvx_vec_fast_sigmoid_fp32(v_src[i]);
+        v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i]);
     }
 }