Просмотр исходного кода

hexagon: add support for ROPE_NEOX (#17458)

Max Krasnyansky 1 месяц назад
Родитель
Сommit
923ae3c619
2 измененных файлов с 81 добавлено и 8 удалено
  1. 1 1
      ggml/src/ggml-hexagon/ggml-hexagon.cpp
  2. 80 7
      ggml/src/ggml-hexagon/htp/rope-ops.c

+ 1 - 1
ggml/src/ggml-hexagon/ggml-hexagon.cpp

@@ -2229,7 +2229,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
 
     int mode = op_params[2];
 
-    if ((mode & GGML_ROPE_TYPE_NEOX) || (mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
+    if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
         return false;
     }
     if (mode & 1) {

+ 80 - 7
ggml/src/ggml-hexagon/htp/rope-ops.c

@@ -24,6 +24,10 @@
 #include "hvx-utils.h"
 #include "ops-utils.h"
 
+// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h
+#define HTP_ROPE_TYPE_NORMAL 0
+#define HTP_ROPE_TYPE_NEOX   2
+
 #define htp_rope_preamble              \
     const uint32_t ne00 = src0->ne[0]; \
     const uint32_t ne01 = src0->ne[1]; \
@@ -146,6 +150,57 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
          rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
 }
 
+static void hvx_calc_rope_neox_f32(const float * restrict src0,
+                              float * restrict dst,
+                              const int num_elems,
+                              const float * restrict theta_cache) {
+    // for (int i = 0; i < num_elems; i += 2) {
+    //const float cos_theta = theta_cache[i + 0];
+    //const float sin_theta = theta_cache[i + 1];
+
+    //const float x0 = src[0];
+    //const float x1 = src[num_elems/2];
+
+    //dst[0] = x0*cos_theta - x1*sin_theta;
+    //dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
+
+    //src += 1;
+    //dst += 1;
+    // }
+
+    const uint8_t * restrict src0_curr  = (const uint8_t *) src0;
+    const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
+    uint8_t * restrict dst_curr         = (uint8_t *) dst;
+
+    int step_of_1 = num_elems >> 6;  // 6 because we process two vectors at once
+    int half_size = (sizeof(float) * (num_elems / 2));
+
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector v0 = *(HVX_Vector *) src0_curr;
+        HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
+
+        HVX_Vector v2 = *(HVX_Vector *) theta_curr;
+        HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+
+        HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);  // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
+
+        HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin));
+        HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin));
+        HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin));
+        HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin));
+
+        HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
+        HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
+
+        *(HVX_Vector *) dst_curr          = Q6_Vsf_equals_Vqf32(v4);
+        *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
+
+        src0_curr += VLEN;
+        theta_curr += 2 * VLEN;
+        dst_curr += VLEN;
+    }
+}
+
 static void hvx_calc_rope_f32(const float * restrict src0,
                               float * restrict dst,
                               const int num_elems,
@@ -212,6 +267,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
     const struct htp_tensor * src2 = &octx->src2;
     struct htp_tensor *       dst  = &octx->dst;
 
+    const int32_t mode  = rope_ctx->mode;
+    const bool is_neox  = mode & HTP_ROPE_TYPE_NEOX;
+
     htp_rope_preamble;
 
     const int32_t * pos = (const int32_t *) src1->data;
@@ -247,20 +305,35 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
                 float *       dst_data_loc = dst_data;
 
                 if (1 == opt_path) {
-                    hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
+                    if (is_neox) {
+                        hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
+                    } else {
+                        hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
+                    }
                 } else {
                     for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
                         const float cos_theta = wp0[i0 + 0];
                         const float sin_theta = wp0[i0 + 1];
 
-                        const float x0 = src_loc[0];
-                        const float x1 = src_loc[1];
+                        if (is_neox) {
+                            const float x0 = src_loc[0];
+                            const float x1 = src_loc[rope_ctx->n_dims/2];
+
+                            dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
+                            dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
+
+                            src_loc += 1;
+                            dst_data_loc += 1;
+                        } else {
+                            const float x0 = src_loc[0];
+                            const float x1 = src_loc[1];
 
-                        dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
-                        dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
+                            dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
+                            dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
 
-                        src_loc += 2;
-                        dst_data_loc += 2;
+                            src_loc += 2;
+                            dst_data_loc += 2;
+                        }
                     }
                 }