1
0
Эх сурвалжийг харах

ggml-hexagon: fix `rope` failure at `test-backend-ops` (#17565)

* fix test failure

* fix: correct scaling calculations in rope_cache_init

* fix: optimize element copying in rope_hex_f32 using memcpy

* fix: optimize loop boundaries in rope_hex_f32 for better performance

* feat: add profiling macros for performance measurement in operations
nullname 1 сар өмнө
parent
commit
34ce48d97a

+ 37 - 41
ggml/src/ggml-hexagon/htp/rope-ops.c

@@ -73,15 +73,15 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
     return (1 - MIN(1, MAX(0, y)));
 }
 
-static void rope_cache_init(const float   theta_base,
-                            float         freq_scale,
-                            const float * freq_factors,
-                            float *       corr_dims,
-                            uint32_t      ne0,
-                            float         ext_factor,
-                            float         mscale,
-                            float *       cache,
-                            float         theta_scale) {
+static void rope_cache_init(const float    theta_base,
+                            const float    freq_scale,
+                            const float *  freq_factors,
+                            float *        corr_dims,
+                            const uint32_t ne0,
+                            const float    ext_factor,
+                            const float    mscale,
+                            float *        cache,
+                            const float    theta_scale) {
     // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
     float theta = theta_base;
 
@@ -92,18 +92,19 @@ static void rope_cache_init(const float   theta_base,
 
         // Get n-d rotational scaling corrected for extrapolation
         float theta_interp = freq_scale * theta_extrap;
-        float theta2       = theta_interp;
+        float theta_final  = theta_interp;
+        float mscale_final = mscale;
 
         if (ext_factor != 0.0f) {
             float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
-            theta2         = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+            theta_final    = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 
             // Get n-d magnitude scaling corrected for interpolation
-            mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+            mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
         }
 
-        cache[i0 + 0] = cosf(theta2) * mscale;
-        cache[i0 + 1] = sinf(theta2) * mscale;
+        cache[i0 + 0] = cosf(theta_final) * mscale_final;
+        cache[i0 + 1] = sinf(theta_final) * mscale_final;
 
         theta *= theta_scale;
     }
@@ -151,9 +152,9 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
 }
 
 static void hvx_calc_rope_neox_f32(const float * restrict src0,
-                              float * restrict dst,
-                              const int num_elems,
-                              const float * restrict theta_cache) {
+                                   float * restrict dst,
+                                   const int num_elems,
+                                   const float * restrict theta_cache) {
     // for (int i = 0; i < num_elems; i += 2) {
     //const float cos_theta = theta_cache[i + 0];
     //const float sin_theta = theta_cache[i + 1];
@@ -192,7 +193,7 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
         HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
         HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
 
-        *(HVX_Vector *) dst_curr          = Q6_Vsf_equals_Vqf32(v4);
+        *(HVX_Vector *) dst_curr               = Q6_Vsf_equals_Vqf32(v4);
         *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
 
         src0_curr += VLEN;
@@ -259,7 +260,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
                          const uint32_t       ir1,
                          int                  nth,
                          int                  ith,
-                         int                  opt_path) {
+                         const int            opt_path) {
     struct htp_ops_context * octx = rope_ctx->octx;
 
     const struct htp_tensor * src0 = &octx->src0;
@@ -267,8 +268,8 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
     const struct htp_tensor * src2 = &octx->src2;
     struct htp_tensor *       dst  = &octx->dst;
 
-    const int32_t mode  = rope_ctx->mode;
-    const bool is_neox  = mode & HTP_ROPE_TYPE_NEOX;
+    const int32_t mode    = rope_ctx->mode;
+    const bool    is_neox = mode & HTP_ROPE_TYPE_NEOX;
 
     htp_rope_preamble;
 
@@ -281,8 +282,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
         freq_factors = (const float *) src2->data;
     }
 
-    int ir = 0;
-
+    const uint32_t i1_end       = MIN(ir1, ne1);
+    const int32_t  half_dims    = rope_ctx->n_dims / 2;
+    const size_t   remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
     for (uint32_t i3 = 0; i3 < ne3; i3++) {      // batch
         for (uint32_t i2 = 0; i2 < ne2; i2++) {  // seq-len
             const int32_t p = pos[i2];
@@ -290,14 +292,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
             rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
                             rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
 
-            for (uint32_t i1 = 0; i1 < ne1; i1++) {  // attn-heads
-                if (ir++ < ir0) {
-                    continue;
-                }
-                if (ir > ir1) {
-                    break;
-                }
-
+            for (uint32_t i1 = ir0; i1 < i1_end; i1++) {  // attn-heads
                 const float * src      = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
                 float *       dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
 
@@ -310,6 +305,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
                     } else {
                         hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
                     }
+
+                    src_loc += rope_ctx->n_dims;
+                    dst_data_loc += rope_ctx->n_dims;
                 } else {
                     for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
                         const float cos_theta = wp0[i0 + 0];
@@ -317,10 +315,10 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
 
                         if (is_neox) {
                             const float x0 = src_loc[0];
-                            const float x1 = src_loc[rope_ctx->n_dims/2];
+                            const float x1 = src_loc[half_dims];
 
-                            dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
-                            dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
+                            dst_data_loc[0]         = x0 * cos_theta - x1 * sin_theta;
+                            dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
 
                             src_loc += 1;
                             dst_data_loc += 1;
@@ -335,15 +333,13 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
                             dst_data_loc += 2;
                         }
                     }
-                }
-
-                for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
-                    dst_data_loc[0] = src_loc[0];
-                    dst_data_loc[1] = src_loc[1];
 
-                    src_loc += 2;
-                    dst_data_loc += 2;
+                    src_loc += (is_neox ? half_dims : 0);
+                    dst_data_loc += (is_neox ? half_dims : 0);
                 }
+
+                // TODO: use simd to speed up the remaining elements copy
+                memcpy(dst_data_loc, src_loc, remain_bytes);
             }
         }
     }