浏览代码

ggml-hexagon: swiglu_oai operation (#18114)

* snapshot: debug ggml-hexagon swiglu-oai

* fix: fix hvx_min_scalar_f32

* feat: working swiglu-oai

* chore: fix formating isue
Shouyu 1 月之前
父节点
当前提交
0a0bba05e8

+ 1 - 1
ggml/src/ggml-hexagon/ggml-hexagon.cpp

@@ -3312,7 +3312,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             break;
 
         case GGML_OP_GLU:
-            if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) /* || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI) */) {
+            if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI) ) {
                 supp = ggml_hexagon_supported_activations(sess, op);
             }
             break;

+ 1 - 1
ggml/src/ggml-hexagon/htp/act-ops.c

@@ -231,7 +231,7 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
         // x (src0_spad_data) = std::min(src0_p[k], limit);
         hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
         // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
-        hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc);
+        hvx_clamp_scalar_f32((const uint8_t *) src1, -limit, limit, src1_spad_data, nc);
         // y (src1_spad_data)  = y1 + 1.f
         hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
         // x1 (dst_spad_data) = alpha * (x)

+ 69 - 35
ggml/src/ggml-hexagon/htp/hvx-utils.c

@@ -948,35 +948,45 @@ float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
 void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
     size_t left_over       = num_elems & (VLEN_FP32 - 1);
     size_t num_elems_whole = num_elems - left_over;
-
+    int unalign_address = 0;
     if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
         FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
+        unalign_address = 1;
     }
 
-    assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
-
     const float * src_f = (const float *) src;
 
-    HVX_Vector vec_min = Q6_V_vsplat_R(val);
+    HVX_Vector vec_min = hvx_vec_splat_fp32(val);
 
-    HVX_Vector * restrict vec_in  = (HVX_Vector *) src;
-    HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
+    if(unalign_address == 0){
+        HVX_Vector * restrict vec_in  = (HVX_Vector *) src;
+        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
 
-    #pragma unroll(4)
-    for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-        vec_min    = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
-        *vec_out++ = Q6_Vsf_equals_Vqf32(vec_min);
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector min_clamp    = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
+            *vec_out++ = (min_clamp);
+        }
+    }else{
+        HVX_UVector * restrict vec_in  = (HVX_Vector *) src;
+        HVX_UVector * restrict vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector min_clamp     = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
+            *vec_out++ = (min_clamp);
+        }
     }
 
-    if (left_over > 0) {
+    if (left_over > 0 ) {
         const float * srcf = (const float *) src + num_elems_whole;
         float *       dstf = (float *) dst + num_elems_whole;
 
-        HVX_Vector in = *(HVX_UVector *) srcf;
+        HVX_UVector in = *(HVX_UVector *) srcf;
 
-        vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, in);
+        HVX_UVector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, in);
 
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(vec_min));
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, (min_clamp));
     }
 }
 
@@ -988,46 +998,70 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src,
     size_t left_over       = num_elems & (VLEN_FP32 - 1);
     size_t num_elems_whole = num_elems - left_over;
 
+    int unalign_address = 0;
     if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
         FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
+        unalign_address = 1;
     }
 
-    assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
-
-    HVX_Vector * restrict vec_in  = (HVX_Vector *) src;
-    HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
-
     HVX_Vector range_left  = hvx_vec_splat_fp32(limit_left);
     HVX_Vector range_right = hvx_vec_splat_fp32(limit_right);
 
-    #pragma unroll(4)
-    for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-        HVX_Vector in_vec = *vec_in++;
-        HVX_Vector temp_v = in_vec;
+    if(unalign_address == 0){
+        HVX_Vector * restrict vec_in  = (HVX_Vector *) src;
+        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
 
-        HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
-        HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
 
-        in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
-        in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
 
-        *vec_out++ = Q6_Vsf_equals_Vqf32(in_vec);
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in_vec = *vec_in++;
+            HVX_Vector temp_v = in_vec;
+
+            HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
+            HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
+
+            in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
+            in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
+
+            *vec_out++ = in_vec;
+        }
+
+    }else{
+
+        HVX_UVector * restrict vec_in  = (HVX_UVector *) src;
+        HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in_vec = *vec_in++;
+            HVX_Vector temp_v = in_vec;
+
+            HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
+            HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
+
+            in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
+            in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
+
+            *vec_out++ = in_vec;
+        }
+
     }
 
     if (left_over > 0) {
         const float * srcf = (const float *) src + num_elems_whole;
         float *       dstf = (float *) dst + num_elems_whole;
 
-        HVX_Vector in = *(HVX_UVector *) srcf;
+        HVX_Vector in_vec = *(HVX_UVector *) srcf;
 
-        HVX_Vector temp_v = in;
+        HVX_Vector temp_v = in_vec;
 
-        HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in, range_right);
-        HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(range_left, in);
+        HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
+        HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
 
-        in = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
-        in = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
+        in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
+        in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
 
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(in));
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
     }
 }

+ 1 - 0
ggml/src/ggml-hexagon/htp/main.c

@@ -807,6 +807,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 break;
 
             case HTP_OP_GLU_SWIGLU:
+            case HTP_OP_GLU_SWIGLU_OAI:
             case HTP_OP_SOFTMAX:
                 if ((n_bufs != 2) && (n_bufs != 3)) {
                     FARF(ERROR, "Bad act-req buffer list");