ggml-hexagon: gelu operation (#17921)

* feat: initial support for gelu using sigmoid approximation

* snapshot: faster gelu using polynomial approximation

* test: disable l2-block prefetch in polynomial approximation

* Revert "test: disable l2-block prefetch in polynomail approximation"

This reverts commit 72339994d45b2bed887e79994403c378d90b62b5.

* Revert "snapshot: faster gelu using polynomial approximation"

This reverts commit 2a787a61d11f9e63e5943a2e6d134b2f0c402ace.

* debug: temporarily disable unnecessary log message for debugging

* feat: optimized unaligned sigmoid_f32

* feat: larger l2prefetch block

* feat: apply unaligned-load optimization on mul and mul_scalar

* Revert "debug: temporarily disable unnecessary log message for debug purpose"

This reverts commit 84f2f23aa9f17e2fa826db969cd825d0ab192995.

* refactor: clean up commented-out unused code

* chore: reformat code with clang-format to pass CI test

* Revert "chore: reformat code with clang-formatter to pass cli test"

This reverts commit 952877ec24732b12010c7fa7ed3fc8de4b74e718.

* fix: fix loop overflow

* chore: fix formatting CI error
Shouyu, 1 month ago
Commit 4470a0764a

+ 18 - 3
ggml/src/ggml-hexagon/ggml-hexagon.cpp

@@ -2161,8 +2161,14 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
     }
 
     // src0, src1 & dst must be mapped to the same session
-    if (!hex_supported_buffer(sess, src0, src1, dst)) {
-        return false;
+    if (src1) {
+        if (!hex_supported_buffer(sess, src0, src1, dst)) {
+            return false;
+        }
+    } else {
+        if (!hex_supported_buffer(sess, src0, dst)) {
+            return false;
+        }
     }
 
     return true;
@@ -2662,6 +2668,10 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
                 req.op    = HTP_OP_UNARY_SILU;
                 supported = true;
             }
+            else if (ggml_get_unary_op(dst) == GGML_UNARY_OP_GELU) {
+                req.op    = HTP_OP_UNARY_GELU;
+                supported = true;
+            }
             break;
 
         case GGML_OP_GLU:
@@ -2677,6 +2687,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
         case GGML_OP_SOFT_MAX:
             req.op    = HTP_OP_SOFTMAX;
             supported = true;
+            break;
 
         default:
             break;
@@ -2956,6 +2967,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
             case GGML_OP_UNARY:
                 if (ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) {
                     ggml_hexagon_unary(node, flags);
+                } else if (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU) {
+                    ggml_hexagon_unary(node, flags);
                 }
                 break;
             case GGML_OP_GLU:
@@ -3254,7 +3267,6 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
     auto sess = static_cast<ggml_hexagon_session *>(dev->context);
 
     bool supp = false;
-
     switch (op->op) {
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -3294,6 +3306,9 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             if (ggml_get_unary_op(op) == GGML_UNARY_OP_SILU) {
                 supp = ggml_hexagon_supported_activations(sess, op);
             }
+            else if (ggml_get_unary_op(op) == GGML_UNARY_OP_GELU) {
+                supp = ggml_hexagon_supported_activations(sess, op);
+            }
             break;
 
         case GGML_OP_GLU:

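Note on the first hunk: unary ops such as GELU have a single input, so a GGML_OP_UNARY node populates op->src[0] and leaves op->src[1] NULL, and the buffer check has to branch rather than pass a NULL tensor to hex_supported_buffer(). A minimal sketch of the pointers involved (names from ggml; the src0/src1 extraction is assumed to match the function's existing preamble, which the hunk truncates):

    /* Sketch only: source layout for the node being checked. */
    const struct ggml_tensor * src0 = op->src[0];  /* always present            */
    const struct ggml_tensor * src1 = op->src[1];  /* NULL for unary ops (GELU) */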
+ 89 - 1
ggml/src/ggml-hexagon/htp/act-ops.c

@@ -255,6 +255,91 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
          src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
+
+static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
+                                       struct htp_tensor *       dst,
+                                       const int32_t *           op_params,
+                                       struct htp_spad *         src0_spad,
+                                       struct htp_spad *         dst_spad,
+                                       uint32_t                  nth,
+                                       uint32_t                  ith,
+                                       uint32_t                  src0_nrows_per_thread) {
+    htp_act_preamble2;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const size_t src0_row_size = nb01;
+    const size_t dst_row_size  = nb1;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    int is_aligned = 1;
+    int opt_path   = 0;
+    if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
+        is_aligned = 0;
+        FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
+    }
+    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+
+    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
+    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_row_size);
+
+    const int BLOCK = 8;
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+        const uint32_t block_end = MIN(ir + BLOCK, src0_end_row);
+
+        // Prefetch next block
+        if (block_end < src0_end_row) {
+            const float * restrict prefetch_ptr = (float *) (data_src0 + (block_end * src0_row_size));
+            htp_l2fetch(prefetch_ptr, 1, block_end * src0_row_size, src0_row_size);
+        }
+
+        // Process rows in current block
+        for (uint32_t ib = ir; ib < block_end; ib++) {
+            const float * restrict src0 = (float *) (data_src0 + (ib * src0_row_size));
+            float * restrict dst        = (float *) (data_dst + (ib * dst_row_size));
+
+            // gelu = x * sigmoid(1.702 * x) // current implementation
+            if (1 == opt_path) {
+                hvx_mul_scalar_f32((const uint8_t *) src0, (float) 1.702, (uint8_t *) src0_spad_data, ne0);
+                hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0);
+                hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
+            } else {
+                hvx_mul_scalar_f32((const uint8_t *) src0, (float) 1.702, (uint8_t *) src0_spad_data, ne0);
+                hvx_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0);
+                hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "gelu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
+         ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+    unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
+                               octx->src0_nrows_per_thread);
+}
+
+
+
 static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
                                        struct htp_tensor *       dst,
                                        const int32_t *           op_params,
@@ -371,7 +456,10 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
             act_op_func = glu_swiglu_oai_fp32;
             op_type     = "swiglu-oai-f32";
             break;
-
+        case HTP_OP_UNARY_GELU:
+            act_op_func = unary_gelu_fp32;
+            op_type     = "gelu-f32";
+            break;
         default:
             FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
             return HTP_STATUS_NO_SUPPORT;

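For reference, the kernel above implements the sigmoid-based GELU approximation noted in its inline comment, gelu(x) ≈ x * sigmoid(1.702 * x). A scalar sketch (not part of the commit) that the HVX path, mul_scalar → sigmoid → mul, should agree with up to the accuracy of the guarded exp:

    #include <math.h>

    /* Scalar reference for gelu(x) ~= x * sigmoid(1.702 * x). The HVX kernel
     * computes the same three steps with hvx_mul_scalar_f32, hvx_sigmoid_f32
     * (or hvx_fast_sigmoid_f32 on the aligned path) and hvx_mul_f32. */
    static void gelu_sigmoid_ref_f32(const float * restrict src, float * restrict dst, int n) {
        for (int i = 0; i < n; i++) {
            const float x = src[i];
            dst[i] = x / (1.0f + expf(-1.702f * x));
        }
    }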
+ 6 - 5
ggml/src/ggml-hexagon/htp/htp-msg.h

@@ -51,11 +51,12 @@ enum htp_op {
     HTP_OP_MUL_MAT_ID     = 5,
     HTP_OP_RMS_NORM       = 6,
     HTP_OP_UNARY_SILU     = 7,
-    HTP_OP_GLU_SWIGLU     = 8,
-    HTP_OP_GLU_SWIGLU_OAI = 9,
-    HTP_OP_SOFTMAX        = 10,
-    HTP_OP_ADD_ID         = 11,
-    HTP_OP_ROPE           = 12,
+    HTP_OP_UNARY_GELU     = 8,
+    HTP_OP_GLU_SWIGLU     = 9,
+    HTP_OP_GLU_SWIGLU_OAI = 10,
+    HTP_OP_SOFTMAX        = 11,
+    HTP_OP_ADD_ID         = 12,
+    HTP_OP_ROPE           = 13,
     INVALID
 };
 

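Since htp_op values travel in dspqueue messages between the CPU backend and the DSP, inserting HTP_OP_UNARY_GELU = 8 renumbers every later opcode and requires both sides to be rebuilt together. A hypothetical compile-time pin (not part of the commit) that would catch a stale build:

    /* Hypothetical guard: fail the build if the renumbered values drift. */
    _Static_assert(HTP_OP_UNARY_GELU == 8, "htp_op renumbered; rebuild CPU and DSP sides together");
    _Static_assert(HTP_OP_ROPE == 13, "htp_op renumbered; rebuild CPU and DSP sides together");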
+ 85 - 12
ggml/src/ggml-hexagon/htp/hvx-utils.c

@@ -49,6 +49,8 @@ void hvx_mul_f32(const uint8_t * restrict src0,
         FARF(HIGH, "hvx_mul_f32: unaligned loop in hvx op, possibly slower execution\n");
     }
 
+
+    bool handled_leftover = false;
     if (0 == unaligned_loop) {
         HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
         HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
@@ -60,18 +62,59 @@ void hvx_mul_f32(const uint8_t * restrict src0,
             *vec_out++   = Q6_Vsf_equals_Vqf32(v);
         }
     } else {
+        int step_of_1 = num_elems_whole >> 5;  // divide by 32: 32 floats (128 bytes) per HVX vector
+        int leftover_size = left_over * sizeof(float);
+
+
+        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
+        HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
+        HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
+
+        HVX_Vector slinep;
+        HVX_Vector slinec;
+        HVX_Vector sline;
+        HVX_Vector sline2p;
+        HVX_Vector sline2c;
+        HVX_Vector sline2;
+
+        slinep  = *vec_in1++;
+        sline2p = *vec_in2++;
         #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
-            HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
+        for (int i = step_of_1 - 1; i > 0; i--) {
+            slinec  = *vec_in1++;
+            sline2c = *vec_in2++;
+            sline   = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
+            sline2  = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
+
+            *((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2));
+            slinep                         = slinec;
+            sline2p                        = sline2c;
+        }
+        if (step_of_1 > 0) {
+            slinec  = htp_is_aligned(vec_in1, VLEN) && left_over == 0 ? slinep : *vec_in1++;
+            sline2c = htp_is_aligned(vec_in2, VLEN) && left_over == 0 ? sline2p : *vec_in2++;
+
+            sline                          = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
+            sline2                         = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
+            *((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2));
+            slinep                         = slinec;
+            sline2p                        = sline2c;
+        }
+        if (left_over > 0) {
+            slinec = (is_in_one_chunk(vec_in1, leftover_size, VLEN) ? slinep : *vec_in1++);
 
-            HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2);
+            sline   = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
+            sline2c = (is_in_one_chunk(vec_in2, leftover_size, VLEN) ? sline2p : *vec_in2++);
+            sline2  = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
 
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+            HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(sline, sline2);
+            hvx_vec_store_u(vec_out, leftover_size, Q6_Vsf_equals_Vqf32(out));
+            handled_leftover = true;
         }
     }
 
-    if (left_over > 0) {
+
+    if (left_over > 0 && !handled_leftover) {
         const float * src0f = (const float *) src0 + num_elems_whole;
         const float * src1f = (const float *) src1 + num_elems_whole;
         float *       dstf  = (float *) dst + num_elems_whole;
@@ -464,7 +507,7 @@ void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
     }
 
     HVX_Vector val_vec = hvx_vec_splat_fp32(val);
-
+    bool handled_leftover = false;
     if (0 == unaligned_loop) {
         HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
         HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
@@ -475,17 +518,47 @@ void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
             *vec_out++   = Q6_Vsf_equals_Vqf32(v);
         }
     } else {
+        int step_of_1 = num_elems >> 5;  // divide by 32: 32 floats (128 bytes) per HVX vector
+        int leftover_size = left_over * sizeof(float);
+
+        HVX_Vector *  input_v_ptr  = (HVX_Vector *) src;
+        HVX_UVector * output_v_ptr = (HVX_UVector *) dst;
+
+        HVX_Vector slinep;
+        HVX_Vector slinec;
+        HVX_Vector sline;
+
+        slinep = *input_v_ptr++;
+
         #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+        for (int i = step_of_1 - 1; i > 0; i--) {
+            slinec                              = *input_v_ptr++;
+            sline                               = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
+            *((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
+            /* Prepare slinep for next iteration */
+            slinep                              = slinec;
+        }
 
-            HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec);
+        if (step_of_1 > 0) {
+            slinec = htp_is_aligned(input_v_ptr, VLEN) && left_over == 0 ? slinep : *input_v_ptr++;
+            sline  = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
+            *((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
 
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+            slinep = slinec;
+        }
+
+        if (leftover_size > 0) {
+            slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) ? slinep : *input_v_ptr++);
+
+            sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
+
+            HVX_Vector sout = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
+            hvx_vec_store_u(output_v_ptr, leftover_size, sout);
+            handled_leftover = true;
         }
     }
 
-    if (left_over > 0) {
+    if (left_over > 0 && !handled_leftover) {
         const float * srcf = (const float *) src + num_elems_whole;
         float *       dstf = (float *) dst + num_elems_whole;
 

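The rewritten unaligned loops above all use the same HVX idiom: issue only aligned vector loads and splice each unaligned vector out of two consecutive aligned ones with Q6_V_valign_VVR, whose byte shift comes from the low bits of the source address. A portable sketch of what the splice computes (memcpy stands in for the single-instruction align; VLEN = 128 as in the HVX code):

    #include <stdint.h>
    #include <string.h>

    #define VLEN 128  /* HVX vector length in bytes */

    /* Read the unaligned VLEN bytes at 'src' using only aligned loads.
     * Q6_V_valign_VVR(curr, prev, (size_t) src) yields the last (VLEN - off)
     * bytes of 'prev' followed by the first 'off' bytes of 'curr',
     * where off = (uintptr_t) src % VLEN. */
    static void load_unaligned_via_align(const uint8_t * src, uint8_t out[VLEN]) {
        const size_t          off     = (uintptr_t) src & (VLEN - 1);
        const uint8_t * const aligned = src - off;  /* aligned chunk holding src */

        uint8_t prev[VLEN], curr[VLEN];
        memcpy(prev, aligned, VLEN);         /* aligned load of current chunk */
        memcpy(curr, aligned + VLEN, VLEN);  /* aligned load of next chunk    */

        memcpy(out, prev + off, VLEN - off); /* tail of prev ...              */
        memcpy(out + VLEN - off, curr, off); /* ... spliced with head of curr */
    }

When off is zero the second load is redundant and may touch bytes past the end of the buffer, which is why the kernels above conditionally reuse slinep (the htp_is_aligned and is_in_one_chunk ternaries) instead of unconditionally reading the next vector.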
+ 57 - 0
ggml/src/ggml-hexagon/htp/hvx-utils.h

@@ -265,12 +265,16 @@ static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t
     }
 }
 
+
+/* Return whether the 'n' bytes starting at 'addr' fit within a single chunk of 'chunk_size' bytes. */
 static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
     uint32_t left_off  = (size_t) addr & (chunk_size - 1);
     uint32_t right_off = left_off + n;
     return right_off <= chunk_size;
 }
 
+
+
 static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
     HVX_VectorAlias u = { .v = v };
 
@@ -994,6 +998,59 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t *
     }
 }
 
+
+static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
+    int step_of_1 = num_elems >> 5;  // divide by 32: 32 floats (128 bytes) per HVX vector
+    int leftover = num_elems - (step_of_1 * VLEN_FP32);
+
+    int32_t leftover_size = leftover * sizeof(float);
+
+    static const float kMinExp = -87.f;  // below this, sigmoid saturates to 0
+    static const float kMaxExp = 87.f;   // above this, sigmoid saturates to 1
+
+    const HVX_Vector one     = hvx_vec_splat_fp32(1.f);
+    const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
+    const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
+
+    const float *input = (float *)src;
+    float *output = (float *)dst;
+
+    HVX_Vector *  input_v_ptr  = (HVX_Vector *) input;
+    HVX_UVector * output_v_ptr = (HVX_UVector *) output;
+
+    HVX_Vector slinep;
+    HVX_Vector slinec;
+    HVX_Vector sline;
+
+    slinep = *input_v_ptr++;
+    #pragma unroll(4)
+    for (int i = step_of_1 - 1; i > 0; i--) {
+        slinec                              = *input_v_ptr++;
+        sline                               = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
+        *((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
+        /* Prepare slinep for next iteration */
+        slinep                              = slinec;
+    }
+
+    if (step_of_1 > 0) {
+        slinec = htp_is_aligned(input_v_ptr, VLEN) && leftover == 0 ? slinep : *input_v_ptr++;
+        sline  = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
+        *((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
+
+        slinep = slinec;
+    }
+    if (leftover > 0) {
+        slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) ? slinep : *input_v_ptr++);
+
+        sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
+
+        HVX_Vector sout = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
+        hvx_vec_store_u(output_v_ptr, leftover_size, sout);
+    }
+}
+
+
 float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
 void  hvx_mul_f32(const uint8_t * restrict src0,
                   const uint8_t * restrict src1,

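hvx_sigmoid_f32 above defers the per-vector math to hvx_vec_fast_sigmoid_fp32_guard with ±87 clamp bounds. The guard's body is not shown in this diff, but the bounds suggest clamping the exp argument so fp32 expf cannot overflow (it overflows near x ≈ 88.7). A scalar analogue under that assumption:

    #include <math.h>

    /* Assumed scalar analogue of hvx_vec_fast_sigmoid_fp32_guard: clamp x to
     * [-87, 87] so expf stays finite in fp32; beyond the clamp the sigmoid
     * has already saturated to 0 or 1. */
    static float sigmoid_guarded_f32(float x) {
        if (x >  87.0f) x =  87.0f;  /* sigmoid -> 1 */
        if (x < -87.0f) x = -87.0f;  /* sigmoid -> 0 */
        return 1.0f / (1.0f + expf(-x));
    }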
+ 1 - 0
ggml/src/ggml-hexagon/htp/main.c

@@ -798,6 +798,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 break;
 
             case HTP_OP_UNARY_SILU:
+            case HTP_OP_UNARY_GELU:
                 if (n_bufs != 2) {
                     FARF(ERROR, "Bad act-req buffer list");
                     continue;