Parcourir la source

hexagon: various Op fixes (#17135)

* hexagon: explicitly check for ops with zero nrows

llm_graph_context::build_inp_out_ids() can generate tensors with zero nrows.
Somehow other backends seems to handle this without obvious explicit checks.
In the hexagon case we need to check explicitly and skip them.

* hexagon: introduce fastdiv, fix test-backend-ops for ADD/SUB/MUL

Co-authored-by: chraac <chraac@gmail.com>

* hexagon: use fastdiv in ADD_ID

* hexagon: use ggml_op_is_empty and ggml_is_empty to check for NOPs

---------

Co-authored-by: chraac <chraac@gmail.com>
Max Krasnyansky il y a 2 mois
Parent
commit
c273d75375

+ 12 - 25
ggml/src/ggml-hexagon/ggml-hexagon.cpp

@@ -3156,26 +3156,17 @@ static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op
     return (op0 && op0->src[1] == op1->src[1]);
 }
 
+static inline bool is_compute_op(ggml_tensor *node)
+{
+    return !(ggml_op_is_empty(node->op) || ggml_is_empty(node));
+}
+
 // scan the graph and figure out last compute op index
 static inline int last_compute_op(ggml_cgraph * graph) {
-    int last;
+    int last = 0;
     for (int i = 0; i < graph->n_nodes; ++i) {
-        ggml_tensor * node = graph->nodes[i];
-
-        switch (node->op) {
-            case GGML_OP_MUL_MAT:
-            case GGML_OP_MUL_MAT_ID:
-            case GGML_OP_MUL:
-            case GGML_OP_ADD:
-            case GGML_OP_SUB:
-            case GGML_OP_RMS_NORM:
-            case GGML_OP_GLU:
-            case GGML_OP_ADD_ID:
-                last = i;
-                break;
-
-            default:
-                break;
+        if (is_compute_op(graph->nodes[i])) {
+            last = i;
         }
     }
 
@@ -3194,6 +3185,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
     for (int i = 0; i < graph->n_nodes; ++i) {
         ggml_tensor * node = graph->nodes[i];
 
+        if (!is_compute_op(node)) {
+            continue;
+        }
+
         uint32_t flags = 0;
 
         // skip quantizer if src1 is reused
@@ -3245,14 +3240,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 ggml_hexagon_rope(node, flags);
                 break;
 
-            // non-compute ops
-            case GGML_OP_NONE:
-            case GGML_OP_RESHAPE:
-            case GGML_OP_VIEW:
-            case GGML_OP_PERMUTE:
-            case GGML_OP_TRANSPOSE:
-                break;
-
             default:
                 GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
         }

+ 46 - 30
ggml/src/ggml-hexagon/htp/binary-ops.c

@@ -34,6 +34,11 @@ static hvx_elemwise_f32_func func_table_HVX[]     = { hvx_mul_f32, hvx_add_f32,
 static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
 
 #define htp_binary_preamble            \
+    const struct htp_tensor * src0 = &octx->src0; \
+    const struct htp_tensor * src1 = &octx->src1; \
+    const struct htp_tensor * src2 = &octx->src2; \
+    struct htp_tensor *       dst  = &octx->dst;  \
+                                       \
     const uint32_t ne00 = src0->ne[0]; \
     const uint32_t ne01 = src0->ne[1]; \
     const uint32_t ne02 = src0->ne[2]; \
@@ -62,16 +67,15 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f
     const uint32_t nb0 = dst->nb[0];   \
     const uint32_t nb1 = dst->nb[1];   \
     const uint32_t nb2 = dst->nb[2];   \
-    const uint32_t nb3 = dst->nb[3];
-
-static void binary_job_f32_per_thread(const struct htp_tensor * src0,
-                                      const struct htp_tensor * src1,
-                                      struct htp_tensor *       dst,
-                                      uint8_t *                 spad_data,
-                                      uint32_t                  nth,
-                                      uint32_t                  ith,
-                                      uint32_t                  src0_nrows_per_thread,
-                                      enum htp_op               op) {
+    const uint32_t nb3 = dst->nb[3];   \
+                                       \
+    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
+static void binary_job_f32_per_thread(struct htp_ops_context * octx,
+                                      uint8_t *                spad_data,
+                                      uint32_t                 nth,
+                                      uint32_t                 ith,
+                                      enum htp_op              op) {
     htp_binary_preamble;
 
     const size_t src0_row_size = nb01;
@@ -107,16 +111,23 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
 
     uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
 
-    const uint32_t nr0 = ne00 / ne10;
-
     const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
     uint8_t * restrict dst_ptr        = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
 
     const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
-    const uint8_t * restrict src1_ptr  = NULL;
+
+    const uint32_t ne02_ne01 = ne02 * ne01;
 
     for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
-        src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size;
+        const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
+        const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
+        const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
+
+        const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
+        const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
+        const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
+
+        const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
 
         if (ir + 1 < src0_end_row) {
             htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
@@ -125,6 +136,7 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
             }
         }
 
+        const uint32_t nr0 = ne00 / ne10;
         if (nr0 > 1) {
             if ((1 == is_aligned) && (nr0 == ne00)) {
                 hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
@@ -149,22 +161,17 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
          (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
-                                             const struct htp_tensor * src1,
-                                             const struct htp_tensor * src2,
-                                             struct htp_tensor *       dst,
-                                             uint8_t *                 spad_data,
-                                             uint32_t                  nth,
-                                             uint32_t                  ith,
-                                             uint32_t                  src0_nrows_per_thread,
-                                             hvx_elemwise_f32_func     func_HVX) {
+static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
+                                             uint8_t *                spad_data,
+                                             uint32_t                 nth,
+                                             uint32_t                 ith,
+                                             hvx_elemwise_f32_func    func_HVX) {
     htp_binary_preamble;
 
     const size_t src0_row_size = nb01;
     const size_t src1_row_size = nb11;
     const size_t dst_row_size  = nb1;
 
-    const uint32_t ne02_ne01  = ne02 * ne01;
     const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
@@ -187,10 +194,11 @@ static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
     const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
     uint8_t * restrict data_dst        = (uint8_t *) dst->data;
 
+    const uint32_t ne02_ne01  = ne02 * ne01;
     for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
         // src0 indices
-        const uint32_t i03 = ir / ne02_ne01;
-        const uint32_t i02 = (ir - i03 * ne02_ne01) / ne01;
+        const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
+        const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
         const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
 
         // src1 indices
@@ -234,13 +242,11 @@ static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * dat
         case HTP_OP_MUL:
         case HTP_OP_ADD:
         case HTP_OP_SUB:
-            binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i,
-                                      octx->src0_nrows_per_thread, octx->op);
+            binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
             break;
 
         case HTP_OP_ADD_ID:
-            binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n,
-                                             i, octx->src0_nrows_per_thread, hvx_add_f32);
+            binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
             break;
 
         default:
@@ -321,6 +327,16 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
 
         octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
 
+        octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
+        octx->src0_div3  = init_fastdiv_values(src0->ne[3]);
+        octx->src0_div2  = init_fastdiv_values(src0->ne[2]);
+        octx->src0_div1  = init_fastdiv_values(src0->ne[1]);
+
+        octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
+        octx->src1_div3  = init_fastdiv_values(src1->ne[3]);
+        octx->src1_div2  = init_fastdiv_values(src1->ne[2]);
+        octx->src1_div1  = init_fastdiv_values(src1->ne[1]);
+
         worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
     }
 

+ 4 - 4
ggml/src/ggml-hexagon/htp/htp-msg.h

@@ -119,10 +119,10 @@ static const char * htp_type_name(uint32_t t) {
 #define HTP_MAX_DIMS 4
 
 struct htp_tensor {
-    uint32_t data;              // Buffer offset in the messages, and data pointer on the NSP
-    uint32_t type;              // Data type
-    uint32_t ne[HTP_MAX_DIMS];  // Number of elements
-    uint32_t nb[HTP_MAX_DIMS];  // Stride in bytes (see ggml.h ggml_tensor)
+    uint32_t data;                // Buffer offset in the messages, and data pointer on the NSP
+    uint32_t type;                // Data type
+    uint32_t ne[HTP_MAX_DIMS];    // Number of elements
+    uint32_t nb[HTP_MAX_DIMS];    // Stride in bytes (see ggml.h ggml_tensor)
 };
 
 #define HTP_MAX_OP_PARAMS 64

+ 11 - 0
ggml/src/ggml-hexagon/htp/htp-ops.h

@@ -4,6 +4,7 @@
 #include "htp-ctx.h"
 #include "htp-msg.h"
 #include "worker-pool.h"
+#include "ops-utils.h"
 
 #include <assert.h>
 #include <stdint.h>
@@ -38,6 +39,16 @@ struct htp_ops_context {
     uint32_t src0_nrows_per_thread;
     uint32_t src1_nrows_per_thread;
 
+    struct fastdiv_values src0_div1;  // fastdiv values for ne1
+    struct fastdiv_values src0_div2;  // fastdiv values for ne2
+    struct fastdiv_values src0_div3;  // fastdiv values for ne3
+    struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
+
+    struct fastdiv_values src1_div1;  // fastdiv values for ne1
+    struct fastdiv_values src1_div2;  // fastdiv values for ne2
+    struct fastdiv_values src1_div3;  // fastdiv values for ne3
+    struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
+
     uint32_t flags;
 };
 

+ 33 - 0
ggml/src/ggml-hexagon/htp/ops-utils.h

@@ -31,6 +31,39 @@ static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
     return m * ((n + m - 1) / m);
 }
 
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+// n/d = (mulhi(n, mp) + n) >> L;
+struct fastdiv_values {
+    uint32_t mp;
+    uint32_t l;
+};
+
+static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
+    struct fastdiv_values result = { 0, 0 };
+    // compute L = ceil(log2(d));
+    while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
+        ++(result.l);
+    }
+
+    result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
+    return result;
+}
+
+static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
+    // Compute high 32 bits of n * mp
+    const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32);  // mulhi(n, mp)
+    // add n, apply bit shift
+    return (hi + n) >> vals->l;
+}
+
+static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
+    return n - fastdiv(n, vals) * d;
+}
+
 static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
     const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
     asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));