// vec.cpp

#include "vec.h"

#include <cassert>

// precomputed gelu table for f16 (128 KB)
ggml_fp16_t ggml_table_gelu_f16[1 << 16];

// precomputed quick gelu table for f16 (128 KB)
ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
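
// ggml_vec_dot_f32: single-row dot product, *s = sum_i x[i]*y[i] over n elements.
// Scalar sketch of what the SIMD paths below compute:
//   float acc = 0.0f;
//   for (int i = 0; i < n; ++i) acc += x[i]*y[i];
//   *s = acc;
// bs/bx/by look like the strides used by multi-row callers and nrc like a row count;
// this implementation only handles nrc == 1, so they are unused here.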
void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

#if defined(GGML_SIMD)
    float sumf = 0.0f;

#if defined(__ARM_FEATURE_SVE)
    const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
    const int ggml_f32_epr  = sve_register_length / 32; // elements per register (svcntw()): SVE128: 4, SVE256: 8, SVE512: 16
    const int ggml_f32_step = 8 * ggml_f32_epr;         // unroll over 8 SVE registers

    const int np = (n & ~(ggml_f32_step - 1));

    svfloat32_t sum1 = svdup_n_f32(0.0f);
    svfloat32_t sum2 = svdup_n_f32(0.0f);
    svfloat32_t sum3 = svdup_n_f32(0.0f);
    svfloat32_t sum4 = svdup_n_f32(0.0f);
    svfloat32_t sum5 = svdup_n_f32(0.0f);
    svfloat32_t sum6 = svdup_n_f32(0.0f);
    svfloat32_t sum7 = svdup_n_f32(0.0f);
    svfloat32_t sum8 = svdup_n_f32(0.0f);

    svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
    svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
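
    // eight independent accumulators, one per unrolled register; presumably this keeps
    // several FMAs in flight to hide their latency. They are reduced to a single scalar
    // after the tail handling below.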
    for (int i = 0; i < np; i += ggml_f32_step) {
        ax1 = GGML_F32_VEC_LOAD(x + i);
        ay1 = GGML_F32_VEC_LOAD(y + i);
        sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);

        ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
        ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
        sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);

        ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
        ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
        sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);

        ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
        ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
        sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);

        ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
        ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
        sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);

        ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
        ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
        sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);

        ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
        ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
        sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);

        ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
        ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
        sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
    }
    // leftovers
    // since the loop above is unrolled by 8, the leftovers are in the range [0, ggml_f32_step)
    // and are handled by the loop below
    const int np2 = (n & ~(ggml_f32_epr - 1));
    for (int i = np; i < np2; i += ggml_f32_epr) {
        ax1 = GGML_F32_VEC_LOAD(x + i);
        ay1 = GGML_F32_VEC_LOAD(y + i);
        sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
    }

    // at most ggml_f32_epr - 1 elements remain; apply a predicated svmad on the available elements only
    if (np2 < n) {
        svbool_t pg = svwhilelt_b32(np2, n);
        ax1 = svld1_f32(pg, x + np2);
        ay1 = svld1_f32(pg, y + np2);
        sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
    }
    // reduce sum1..sum8 into sumf
    GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
#elif defined(__riscv_v_intrinsic)
    int vl = __riscv_vsetvlmax_e32m8();
    vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    vfloat32m8_t vsum;
    vfloat32m8_t ax;
    vfloat32m8_t ay;
    vsum = __riscv_vfmv_v_f_f32m8_tu(vsum, 0.0f, vl);
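
    // strip-mined loop: each iteration asks vsetvl how many elements the hardware will
    // process, and the tail-undisturbed (_tu) loads/FMA leave previously accumulated
    // lanes intact when the last iteration is shorter than a full vector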
    for (int i = 0; i < n; i += vl) {
        vl = __riscv_vsetvl_e32m8(n - i);
        ax = __riscv_vle32_v_f32m8_tu(ax, &x[i], vl);
        ay = __riscv_vle32_v_f32m8_tu(ay, &y[i], vl);
        vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, ax, ay, vl);
    }
    vl = __riscv_vsetvlmax_e32m8();
    vs = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vs, vl);
    sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
#else
    const int np = (n & ~(GGML_F32_STEP - 1));

    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

    GGML_F32_VEC ax[GGML_F32_ARR];
    GGML_F32_VEC ay[GGML_F32_ARR];

    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);

            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
        }
    }
    // reduce sum0..sum3 to sumf
    GGML_F32_VEC_REDUCE(sumf, sum);

    // leftovers
    for (int i = np; i < n; ++i) {
        sumf += x[i]*y[i];
    }
#endif
#else
    // scalar
    ggml_float sumf = 0.0;
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(x[i]*y[i]);
    }
#endif

    *s = sumf;
}
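
// ggml_vec_dot_bf16: same single-row dot product, but with bf16 inputs; the SIMD paths
// accumulate the products in f32 and the scalar tail in ggml_float (typically double).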
void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);
    int i = 0;
    ggml_float sumf = 0;

#if defined(__AVX512BF16__)
    __m512 c1 = _mm512_setzero_ps();
    __m512 c2 = _mm512_setzero_ps();

    for (; i + 64 <= n; i += 64) {
        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
                                  m512bh(_mm512_loadu_si512((y + i))));
        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
                                  m512bh(_mm512_loadu_si512((y + i + 32))));
    }
    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
    sumf += (ggml_float)_mm512_reduce_add_ps(c2);

#elif defined(__AVX512F__)
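// LOAD widens 16 bf16 values to f32 by zero-extending each 16-bit pattern and shifting
// it into the high half of a 32-bit lane (bf16 is the upper 16 bits of an f32)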
#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
    __m512 c1 = _mm512_setzero_ps();
    __m512 c2 = _mm512_setzero_ps();
    for (; i + 32 <= n; i += 32) {
        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
    }
    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
#undef LOAD
#elif defined(__AVX2__) || defined(__AVX__)
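// the AVX2 variant of LOAD widens 8 bf16 values at a time in the same way; the plain AVX
// variant has no 256-bit integer shift, so it widens two 128-bit halves and reassembles them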
#if defined(__AVX2__)
#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
#else
#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
#endif
    __m256 c1 = _mm256_setzero_ps();
    __m256 c2 = _mm256_setzero_ps();
    __m256 c3 = _mm256_setzero_ps();
    __m256 c4 = _mm256_setzero_ps();
    for (; i + 32 <= n; i += 32) {
        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
    }
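
    // horizontal reduction: fold the four 256-bit accumulators into one, add its two
    // 128-bit halves, then collapse to a scalar with the movehl/movehdup shuffles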
    __m128 g;
    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
                       _mm256_add_ps(c2, c4));
    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
                   _mm256_castps256_ps128(c1));
    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
    g = _mm_add_ss(g, _mm_movehdup_ps(g));
    sumf += (ggml_float)_mm_cvtss_f32(g);
#undef LOAD
#endif

    // scalar tail (also the full path when no supported SIMD is available)
    for (; i < n; ++i) {
        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
                             GGML_BF16_TO_FP32(y[i]));
    }
    *s = sumf;
}
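
// ggml_vec_dot_f16: single-row dot product over f16 inputs. The SVE path accumulates in
// f16, the RVV path widens to f32, and the remaining paths convert through the
// GGML_F16_VEC helpers or GGML_CPU_FP16_TO_FP32 in the scalar fallback.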
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

    ggml_float sumf = 0.0;

#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
    const int sve_register_length = svcntb() * 8;        // SVE vector length in bits
    const int ggml_f16_epr  = sve_register_length / 16;  // f16 elements per SVE register
    const int ggml_f16_step = 8 * ggml_f16_epr;          // unroll over 8 SVE registers

    const int np = (n & ~(ggml_f16_step - 1));

    svfloat16_t sum1 = svdup_n_f16(0.0f);
    svfloat16_t sum2 = svdup_n_f16(0.0f);
    svfloat16_t sum3 = svdup_n_f16(0.0f);
    svfloat16_t sum4 = svdup_n_f16(0.0f);

    svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
    svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
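
    // eight registers are processed per iteration but only four accumulators are kept:
    // register pairs (1,5), (2,6), (3,7), (4,8) share sum1..sum4, which the final
    // GGML_F16x_VEC_REDUCE folds into sumf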
    for (int i = 0; i < np; i += ggml_f16_step) {
        ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
        ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
        sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);

        ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
        ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
        sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);

        ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
        ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
        sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);

        ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
        ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
        sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);

        ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
        ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
        sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);

        ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
        ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
        sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);

        ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
        ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
        sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);

        ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
        ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
        sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
    }
    const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to a multiple of ggml_f16_epr
    for (int k = np; k < np2; k += ggml_f16_epr) {
        svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
        svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
        sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
    }

    // predicated handling of the final partial vector
    if (np2 < n) {
        svbool_t pg = svwhilelt_b16(np2, n);
        svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
        svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));

        sum1 = svmad_f16_x(pg, hx, hy, sum1);
    }
    GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
    int vl = __riscv_vsetvlmax_e32m2();
    vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    vfloat32m2_t vsum;
    vfloat16m1_t ax;
    vfloat16m1_t ay;
    vsum = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vmv_v_x_u32m2(0, vl));
    for (int i = 0; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m1(n - i);
        ax = __riscv_vle16_v_f16m1_tu(ax, (const _Float16 *)&x[i], vl);
        ay = __riscv_vle16_v_f16m1_tu(ay, (const _Float16 *)&y[i], vl);
        vsum = __riscv_vfwmacc_vv_f32m2_tu(vsum, ax, ay, vl);
    }
    vl = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t ac0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum, 0), __riscv_vget_v_f32m2_f32m1(vsum, 1), vl);
    vs = __riscv_vfredusum_vs_f32m1_f32m1(ac0, vs, vl);
    sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
#else
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
    }
#endif // __riscv_zvfh
#else
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };

    GGML_F16_VEC ax[GGML_F16_ARR];
    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);

            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
        }
    }
    // reduce sum0..sum3 to sumf
    GGML_F16_VEC_REDUCE(sumf, sum);

    // leftovers
    for (int i = np; i < n; ++i) {
        sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
    }

    // if you hit this, you are likely running outside the FP range
    assert(!isnan(sumf) && !isinf(sumf));
#endif
#else
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
    }
#endif // GGML_SIMD

    *s = sumf;
}
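
// ggml_vec_silu_f32: y[i] = silu(x[i]) = x[i] / (1 + exp(-x[i])), vectorized where a
// SIMD ggml_v_silu variant is available; any remaining elements fall through to the
// scalar ggml_silu_f32 loop at the end.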
void ggml_vec_silu_f32(const int n, float * y, const float * x) {
    int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
    }
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    const int vlen = svcntw();
    for (; i < n; i += vlen) {
        const svbool_t pg = svwhilelt_b32_s32(i, n);
        svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i)));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
    }
#endif
    for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]);
    }
}
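
// ggml_vec_swiglu_f32: y[i] = silu(x[i]) * g[i], i.e. the SwiGLU gating used in many
// transformer FFNs, with the same per-ISA dispatch as ggml_vec_silu_f32 above.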
void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
    int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
    }
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    const int vlen = svcntw();
    for (; i < n; i += vlen) {
        const svbool_t pg = svwhilelt_b32_s32(i, n);
        svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i)));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
    }
#elif defined(__riscv_v_intrinsic)
    for (int vl; i < n; i += vl) {
        vl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
        vfloat32m2_t vg = __riscv_vle32_v_f32m2(&g[i], vl);
        vfloat32m2_t vy = __riscv_vfmul_vv_f32m2(ggml_v_silu_m2(vx, vl), vg, vl);
        __riscv_vse32_v_f32m2(&y[i], vy, vl);
    }
#endif
    for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]) * g[i];
    }
}
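
// ggml_vec_soft_max_f32: writes y[i] = exp(x[i] - max) and returns the sum of those
// values; the caller is expected to divide by the returned sum to finish the softmax.
// Subtracting the row maximum first keeps exp() from overflowing.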
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
    int i = 0;
    ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
                                               _mm512_set1_ps(max)));
        _mm512_storeu_ps(y + i, val);
        sum += (ggml_float)_mm512_reduce_add_ps(val);
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
                                               _mm256_set1_ps(max)));
        _mm256_storeu_ps(y + i, val);
        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
                                 _mm256_castps256_ps128(val));
        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
        sum += (ggml_float)_mm_cvtss_f32(val2);
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
                                            _mm_set1_ps(max)));
        _mm_storeu_ps(y + i, val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
        val = _mm_add_ps(val, tmp);
        tmp = _mm_movehl_ps(tmp, val);
        val = _mm_add_ss(val, tmp);
#endif
        sum += (ggml_float)_mm_cvtss_f32(val);
    }
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    const int vlen = svcntw();
    for (; i < n; i += vlen) {
        const svbool_t pg = svwhilelt_b32_s32(i, n);
        svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i),
                                                          svdup_n_f32_x(pg, max)));
        svst1_f32(pg, y + i, val);
        sum += (ggml_float)svaddv_f32(pg, val);
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
                                                vdupq_n_f32(max)));
        vst1q_f32(y + i, val);
        sum += (ggml_float)vaddvq_f32(val);
    }
#elif defined(__riscv_v_intrinsic)
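    // the RVV path strip-mines the whole range with vsetvl, accumulates the partial sums
    // in f64 via the widening reduction, and returns directly, so it never reaches the
    // scalar tail loop below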
    vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
    for (int avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
        __riscv_vse32_v_f32m2(&y[i], val, avl);
        vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl);
    }
    return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
    for (; i < n; ++i) {
        float val = expf(x[i] - max);
        sum += (ggml_float)val;
        y[i] = val;
    }
    return sum;
}
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_sum)
    int i = 0;
    ggml_float sum = 0;
    for (; i < n; ++i) {
        float val = x[i] - max;
        y[i] = val;
        sum += (ggml_float)expf(val);
    }
    return (ggml_float)logf(sum);
}