@@ -7641,8 +7641,8 @@ static void ggml_compute_forward_ssm_scan_f32(
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+ float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+ float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}

// use the output as the source for the next token-wise iterations
if (i2 > 0) { s0 = s; }
@@ -8070,6 +8070,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
#define GGML_F32X_MUL GGML_F32x16_MUL
#define GGML_F32X_FMA GGML_F32x16_FMA
#define WKV_VECTOR_SIZE 16
+ #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+ #define GGML_F32X GGML_F32xt
+ #define GGML_F32X_SET1 GGML_F32xt_SET1
+ #define GGML_F32X_LOAD GGML_F32xt_LOAD
+ #define GGML_F32X_STORE GGML_F32xt_STORE
+ #define GGML_F32X_MUL GGML_F32xt_MUL
+ #define GGML_F32X_FMA GGML_F32xt_FMA
+ #define WKV_VECTOR_SIZE 8
#elif defined(__ARM_NEON) && defined(__aarch64__)
#define GGML_F32X GGML_F32x4
#define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8080,8 +8088,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
#define WKV_VECTOR_SIZE 4
#endif

+ int wkv_vector_size;
#ifdef WKV_VECTOR_SIZE
- const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+ #if defined(__ARM_FEATURE_SVE)
+ wkv_vector_size = svcntw();
+ #else
+ wkv_vector_size = WKV_VECTOR_SIZE;
+ #endif
+ const int64_t vec_count = head_size / wkv_vector_size;

for (int64_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
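
Note (editor's aside, not part of the patch): on SVE the hardware vector length is implementation-defined, so the number of f32 lanes is only known at run time. That is why the code above queries svcntw() for the actual lane count, while the compile-time WKV_VECTOR_SIZE mainly keeps the #ifdef-gated vector path enabled on SVE builds. A minimal sketch of the query, assuming an AArch64 compiler with SVE enabled (e.g. -march=armv8-a+sve); everything outside svcntw() is illustrative:

    // Prints the number of 32-bit lanes per SVE vector, e.g. 4 (128-bit) or 8 (256-bit).
    #include <arm_sve.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        uint64_t lanes = svcntw();   // run-time count of f32 elements per SVE vector
        printf("SVE f32 lanes: %llu\n", (unsigned long long) lanes);
        return 0;
    }
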
@@ -8111,7 +8125,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);

for (int64_t j = 0; j < vec_count; j++) {
- size_t base_j = j * WKV_VECTOR_SIZE;
+ size_t base_j = j * wkv_vector_size;
size_t t_h_j_offset = t_h_offset + base_j;
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8136,7 +8150,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
}

// Handle remaining elements, this will not be used.
- for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+ for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
@@ -8272,6 +8286,14 @@ static void ggml_compute_forward_gla_f32(
#define GGML_F32X_MUL GGML_F32x16_MUL
#define GGML_F32X_FMA GGML_F32x16_FMA
#define GLA_VECTOR_SIZE 16
+ #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+ #define GGML_F32X GGML_F32xt
+ #define GGML_F32X_SET1 GGML_F32xt_SET1
+ #define GGML_F32X_LOAD GGML_F32xt_LOAD
+ #define GGML_F32X_STORE GGML_F32xt_STORE
+ #define GGML_F32X_MUL GGML_F32xt_MUL
+ #define GGML_F32X_FMA GGML_F32xt_FMA
+ #define GLA_VECTOR_SIZE 8
#elif defined(__ARM_NEON) && defined(__aarch64__)
#define GGML_F32X GGML_F32x4
#define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8282,8 +8304,14 @@ static void ggml_compute_forward_gla_f32(
#define GLA_VECTOR_SIZE 4
#endif

+ int gla_vector_size;
#ifdef GLA_VECTOR_SIZE
- const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+ #if defined(__ARM_FEATURE_SVE)
+ gla_vector_size = svcntw();
+ #else
+ gla_vector_size = GLA_VECTOR_SIZE;
+ #endif
+ const int64_t vec_count = head_size / gla_vector_size;

for (int64_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
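
The GLA kernel gets the same treatment as WKV6: gla_vector_size is taken from svcntw() under SVE and from GLA_VECTOR_SIZE otherwise, and both the vectorized loop count and the scalar tail in the hunks below are derived from that run-time value. A small sketch of that split, with hypothetical names (not code from the patch):

    #include <stdint.h>

    // How head_size is divided between the SIMD loop and the scalar tail when the
    // lane count is only known at run time. Illustrative only.
    static void split_work(int64_t head_size, int64_t lanes,
                           int64_t * vec_count, int64_t * tail_start) {
        *vec_count  = head_size / lanes;     // iterations of the vectorized j-loop
        *tail_start = *vec_count * lanes;    // first index handled by the remainder loop
    }
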
@@ -8310,7 +8338,7 @@ static void ggml_compute_forward_gla_f32(
GGML_F32X g_vec = GGML_F32X_SET1(g_val);

for (int64_t j = 0; j < vec_count; j++) {
- size_t base_j = j * GLA_VECTOR_SIZE;
+ size_t base_j = j * gla_vector_size;
size_t t_h_j_offset = t_h_offset + base_j;
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8334,7 +8362,7 @@ static void ggml_compute_forward_gla_f32(
}

// Handle remaining elements, this will not be used.
- for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+ for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
@@ -8443,83 +8471,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_stride_2d = head_size * head_size;

#if defined(GGML_SIMD)
- for (int64_t t = 0; t < T; t++) {
- int64_t t_offset = t * t_stride;
- int64_t state_offset = head_size * C * (t / (T / n_seqs));
- float * state_cur = state + state_offset;
- float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
-
- for (int64_t h = h_start; h < h_end; h++) {
- int64_t h_offset = h * h_stride;
- int64_t t_h_offset = t_offset + h_offset;
- int64_t h_2d_offset = h * h_stride_2d;
-
- for (int64_t ii = 0; ii < head_size; ii++) {
- int64_t t_h_i_offset = t_h_offset + ii;
- int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
- GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+ #if defined(__ARM_FEATURE_SVE)
+ // Route to the scalar implementation  // TODO: write SVE code
+ for (int64_t t = 0; t < T; t++) {
+ int64_t t_offset = t * t_stride;
+ int64_t state_offset = head_size * C * (t / (T / n_seqs));
+ float * state_cur = state + state_offset;
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+ for (int64_t h = h_start; h < h_end; h++) {
+ int64_t h_offset = h * h_stride;
+ int64_t t_h_offset = t_offset + h_offset;
+ int64_t h_2d_offset = h * h_stride_2d;
+
+ for (int64_t i = 0; i < head_size; i++) {
+ int64_t t_h_i_offset = t_h_offset + i;
+ int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+ float v_val = v[t_h_i_offset];
+
+ float sa = 0, result = 0;
+ for (int64_t j = 0; j < head_size; j++) {
+ sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+ }

- float sa = 0;
- {
- GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
- GGML_F32_VEC ax[GGML_F32_ARR];
- GGML_F32_VEC ay[GGML_F32_ARR];
- for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
- for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
- ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
- ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
- sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
- }
+ for (int64_t j = 0; j < head_size; j++) {
+ int64_t t_h_j_offset = t_h_offset + j;
+ int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+ float r_val = r[t_h_j_offset];
+ float w_val = w[t_h_j_offset];
+ float k_val = k[t_h_j_offset];
+ float b_val = b[t_h_j_offset];
+ float kv_val = v_val * k_val;
+ float prev_state_val = state_prev[h_2d_i_j_offset];
+ state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+ result += state_cur[h_2d_i_j_offset] * r_val;
}
- GGML_F32_VEC_REDUCE(sa, sum);
+ dst_data[t_h_i_offset] = result;
}
+ }
+ }
+ #else
+ for (int64_t t = 0; t < T; t++) {
+ int64_t t_offset = t * t_stride;
+ int64_t state_offset = head_size * C * (t / (T / n_seqs));
+ float * state_cur = state + state_offset;
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+ for (int64_t h = h_start; h < h_end; h++) {
+ int64_t h_offset = h * h_stride;
+ int64_t t_h_offset = t_offset + h_offset;
+ int64_t h_2d_offset = h * h_stride_2d;
+
+ for (int64_t ii = 0; ii < head_size; ii++) {
+ int64_t t_h_i_offset = t_h_offset + ii;
+ int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+ GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+ float sa = 0;
+ {
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+ for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+ for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+ ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+ ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+ sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+ }
+ }
+ GGML_F32_VEC_REDUCE(sa, sum);
+ }

- GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+ GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

- int64_t j = 0;
- GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
- for (; j < head_size; j += GGML_F32_STEP) {
- for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
- int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
- int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+ int64_t j = 0;
+ GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+ for (; j < head_size; j += GGML_F32_STEP) {
+ for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+ int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+ int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

- GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
- GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
- GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
- GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+ GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+ GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+ GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+ GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

- k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+ k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

- GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
- // kv + s * decay + sa * b
- state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
- state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
- GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+ GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+ // kv + s * decay + sa * b
+ state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+ state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+ GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

- result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+ result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+ }
+ }
+ GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+ // There shouldn't be left-overs though.
+ for (; j < head_size; j++) {
+ int64_t t_h_j_offset = t_h_offset + j;
+ int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+ float r_val = r[t_h_j_offset];
+ float w_val = w[t_h_j_offset];
+ float k_val = k[t_h_j_offset];
+ float b_val = b[t_h_j_offset];
+ float kv_val = v[t_h_i_offset] * k_val;
+
+ float prev_state_val = state_prev[h_2d_i_j_offset];
+ state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+ dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
}
- }
- GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
- // There shouldn't be left-overs though.
- for (; j < head_size; j++) {
- int64_t t_h_j_offset = t_h_offset + j;
- int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
- float r_val = r[t_h_j_offset];
- float w_val = w[t_h_j_offset];
- float k_val = k[t_h_j_offset];
- float b_val = b[t_h_j_offset];
- float kv_val = v[t_h_i_offset] * k_val;
-
- float prev_state_val = state_prev[h_2d_i_j_offset];
- state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
- dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
}
}
}
- }
+ #endif
#else
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
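
For the WKV7 kernel the patch routes SVE builds to a scalar implementation for now (the TODO above). The recurrence that this fallback computes can be read directly from the added loops; here is a self-contained restatement with hypothetical names (N is the head size, state matrices are N x N, row-major), illustrative only:

    #include <stddef.h>

    // Scalar WKV7 update for one head and one token, as in the SVE fallback above:
    //   sa              = dot(a, state_prev[i])
    //   state_cur[i][j] = state_prev[i][j]*w[j] + v[i]*k[j] + sa*b[j]
    //   out[i]          = dot(r, state_cur[i])
    static void wkv7_head_step(size_t N,
                               const float * r, const float * w, const float * k,
                               const float * v, const float * a, const float * b,
                               const float * state_prev,   // N x N, row i belongs to output i
                               float * state_cur,          // N x N
                               float * out) {              // N
        for (size_t i = 0; i < N; i++) {
            float sa = 0.0f;
            for (size_t j = 0; j < N; j++) {
                sa += a[j] * state_prev[i*N + j];
            }
            float result = 0.0f;
            for (size_t j = 0; j < N; j++) {
                state_cur[i*N + j] = state_prev[i*N + j] * w[j] + v[i] * k[j] + sa * b[j];
                result += state_cur[i*N + j] * r[j];
            }
            out[i] = result;
        }
    }
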