@@ -321,6 +321,12 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     for (; i + 3 < n; i += 4) {
         _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
     }
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    const int vlen = svcntw();
+    for (; i < n; i += vlen) {
+        const svbool_t pg = svwhilelt_b32_s32(i, n);
+        svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i)));
+    }
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
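The SVE branch above replaces the fixed 4-wide tail handling with a predicated loop. For reference, here is the same loop shape in isolation as a minimal sketch (the function name scale_f32_sve and the scale kernel are illustrative, not part of the patch): svcntw() reports how many 32-bit lanes the hardware's SVE vector holds, and svwhilelt_b32_s32(i, n) builds a predicate that deactivates lanes once i plus the lane index reaches n, so the final partial vector is handled by masking instead of a scalar remainder loop.

    #include <arm_sve.h>

    // Minimal sketch of the patch's predicated loop, applied to a simple
    // y[i] = c * x[i] kernel. Compile with SVE enabled, e.g. -march=armv8-a+sve.
    // Inactive lanes on the last iteration are neither loaded nor stored,
    // so no scalar tail loop is needed.
    static void scale_f32_sve(const int n, float * y, const float * x, float c) {
        const int vlen = svcntw();                          // 32-bit lanes per SVE vector
        for (int i = 0; i < n; i += vlen) {
            const svbool_t pg = svwhilelt_b32_s32(i, n);    // lane active iff i + lane < n
            svfloat32_t vx = svld1_f32(pg, x + i);          // masked load
            svst1_f32(pg, y + i, svmul_n_f32_x(pg, vx, c)); // masked scale + store
        }
    }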
@@ -345,6 +351,12 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
     for (; i + 3 < n; i += 4) {
         _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
     }
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    const int vlen = svcntw();
+    for (; i < n; i += vlen) {
+        const svbool_t pg = svwhilelt_b32_s32(i, n);
+        svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i)));
+    }
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
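A note on the svmul_f32_x call in this hunk: the _x suffix selects the "don't care" predication form, which leaves inactive lanes undefined; that is safe here because the masked svst1_f32 never writes those lanes. Element-wise the loop computes y[i] = silu(x[i]) * g[i], which the following scalar sketch spells out (silu_ref and swiglu_ref are hypothetical reference helpers, useful only for checking the vector path against, not part of the patch):

    #include <math.h>

    // Scalar reference for the SWIGLU kernel: SiLU(x) = x * sigmoid(x),
    // gated by g. Hypothetical helper, not part of the patch.
    static inline float silu_ref(float v) { return v / (1.0f + expf(-v)); }

    static void swiglu_ref(const int n, float * y, const float * x, const float * g) {
        for (int i = 0; i < n; ++i) {
            y[i] = silu_ref(x[i]) * g[i];
        }
    }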
@@ -392,6 +404,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
 #endif
         sum += (ggml_float)_mm_cvtss_f32(val);
     }
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    const int vlen = svcntw();
+    for (; i < n; i += vlen) {
+        const svbool_t pg = svwhilelt_b32_s32(i, n);
+        svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i),
+                                                      svdup_n_f32_x(pg, max)));
+        svst1_f32(pg, y + i, val);
+        sum += (ggml_float)svaddv_f32(pg, val);
+    }
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
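The softmax hunk also folds the horizontal reduction into the predicated loop: svaddv_f32(pg, val) sums only the active lanes, so the final partial vector contributes exactly the remaining elements to sum. Below is a minimal sketch of that reduction pattern on its own (sum_f32_sve is an illustrative name, assuming the same SVE baseline as the patch):

    #include <arm_sve.h>

    // Predicated horizontal sum: on the last iteration the predicate masks
    // out-of-range lanes, so they contribute nothing to the total.
    static float sum_f32_sve(const int n, const float * x) {
        float sum = 0.0f;
        const int vlen = svcntw();
        for (int i = 0; i < n; i += vlen) {
            const svbool_t pg = svwhilelt_b32_s32(i, n);
            sum += svaddv_f32(pg, svld1_f32(pg, x + i));
        }
        return sum;
    }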