// Vectorized functions for fundamental operations

#pragma once

#include "ggml-impl.h"
#include "simd-mappings.h"

#include "ggml.h"
#include "ggml-cpu.h"

#if defined(GGML_USE_ACCELERATE)
#include <Accelerate/Accelerate.h>
#endif

// floating point type used to accumulate sums
typedef double ggml_float;

#define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16

#define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL  2
#define GGML_VEC_MAD_UNROLL  32

#ifdef __cplusplus
extern "C" {
#endif

//
// global data
//

// precomputed gelu table for f16 (128 KB)
extern ggml_fp16_t ggml_table_gelu_f16[1 << 16];

// precomputed quick gelu table for f16 (128 KB)
extern ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
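
// The tables are indexed by the raw 16-bit bit pattern of the fp16 argument
// (65536 entries x 2 bytes = 128 KB each), so evaluating gelu for a half value
// is a reinterpret of its bits followed by a single array access. Illustrative
// sketch only (hypothetical helper, not part of the ggml API), mirroring what
// ggml_vec_gelu_f16 and ggml_vec_gelu_f32 below do per element:
//
//     static inline ggml_fp16_t gelu_lookup_f16(ggml_fp16_t h) {
//         uint16_t t;
//         memcpy(&t, &h, sizeof(t));      // reinterpret fp16 bits as table index
//         return ggml_table_gelu_f16[t];  // table stores gelu() of every fp16 value
//     }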

//
// fundamental operations
//

void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);

void ggml_vec_silu_f32(const int n, float * y, const float * x);
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);

inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }

inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
    int i = 0;
#if defined(__AVX2__)
    for (; i + 7 < n; i += 8) {
        __m256 vx = _mm256_loadu_ps(x + i);
        __m256 vy = _mm256_loadu_ps(y + i);
        __m256 vz = _mm256_add_ps(vx, vy);
        _mm256_storeu_ps(z + i, vz);
    }
#endif
    for (; i < n; ++i) {
        z[i] = x[i] + y[i];
    }
}
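
// Usage sketch (illustrative only, not part of the API): element-wise add of
// two small buffers. On AVX2 builds the first floor(n/8)*8 elements go through
// the vector loop above and the tail falls through to the scalar loop.
//
//     float a[5] = { 1, 2, 3, 4, 5 };
//     float b[5] = { 10, 20, 30, 40, 50 };
//     float c[5];
//     ggml_vec_add_f32(5, c, a, b);   // c = { 11, 22, 33, 44, 55 }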

inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
    for (int i = 0; i < n; ++i) {
        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
    }
}

inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }

inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
    for (int i = 0; i < n; ++i) {
        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) - GGML_CPU_FP16_TO_FP32(y[i]));
    }
}

inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }

inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(-GGML_CPU_FP16_TO_FP32(x[i]));
    }
}

inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }

inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
    for (int i = 0; i < n; ++i) {
        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) * GGML_CPU_FP16_TO_FP32(y[i]));
    }
}

inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }

inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
    for (int i = 0; i < n; ++i) {
        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) / GGML_CPU_FP16_TO_FP32(y[i]));
    }
}

// compute GGML_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GGML_RESTRICT s, void * GGML_RESTRICT xv, ggml_fp16_t * GGML_RESTRICT y) {
    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };

    ggml_fp16_t * GGML_RESTRICT x[GGML_VEC_DOT_UNROLL];

    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
    }

#if defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };

    GGML_F16_VEC ax[GGML_F16_ARR];
    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);

            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);

                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
            }
        }
    }

    // reduce the partial sums of each row into sumf[k]
    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
            sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
        }
    }
#else
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
            sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
        }
    }
#endif

    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
        s[i] = (float)sumf[i];
    }
}
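
// Usage sketch (illustrative only): ggml_vec_dot_f16_unroll computes
// GGML_VEC_DOT_UNROLL (= 2) dot products s[k] = sum_i x_k[i]*y[i] in a single
// pass, where row k starts at (char *) xv + k*xs. For rows stored back to
// back, xs is simply n * sizeof(ggml_fp16_t):
//
//     ggml_fp16_t rows[2][64];   // two fp16 rows of 64 elements (assumed initialized)
//     ggml_fp16_t yv[64];        // shared right-hand side       (assumed initialized)
//     float       s[2];
//     ggml_vec_dot_f16_unroll(64, 64*sizeof(ggml_fp16_t), s, rows, yv);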

inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
    const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
    const int ggml_f32_epr  = sve_register_length / 32; // SVE128: 4, SVE256: 8, SVE512: 16
    const int ggml_f32_step = 8 * ggml_f32_epr;         // choose 8 SVE registers

    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

    const int np = (n & ~(ggml_f32_step - 1));

    svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
    svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;

    for (int i = 0; i < np; i += ggml_f32_step) {
        ax1 = GGML_F32_VEC_LOAD(x + i);
        ay1 = GGML_F32_VEC_LOAD(y + i);
        ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
        GGML_F32_VEC_STORE(y + i, ay1);

        ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
        ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
        ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
        GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);

        ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
        ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
        ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
        GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);

        ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
        ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
        ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
        GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);

        ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
        ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
        ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
        GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);

        ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
        ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
        ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
        GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);

        ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
        ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
        ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
        GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);

        ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
        ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
        ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
        GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
    }
    // leftovers
    // since the loop above is unrolled 8x, the leftovers lie in the range [0, ggml_f32_step) and are handled by the loop below
    const int np2 = (n & ~(ggml_f32_epr - 1));
    for (int i = np; i < np2; i += ggml_f32_epr) {
        ax1 = GGML_F32_VEC_LOAD(x + i);
        ay1 = GGML_F32_VEC_LOAD(y + i);
        ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
        GGML_F32_VEC_STORE(y + i, ay1);
    }
    // at most ggml_f32_epr - 1 elements remain; apply a predicated svmad to just those elements
    if (np2 < n) {
        svbool_t pg = svwhilelt_b32(np2, n);
        ax1 = svld1_f32(pg, x + np2);
        ay1 = svld1_f32(pg, y + np2);
        ay1 = svmad_f32_m(pg, ax1, vx, ay1);
        svst1_f32(pg, y + np2, ay1);
    }
#else
    const int np = (n & ~(GGML_F32_STEP - 1));

    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

    GGML_F32_VEC ax[GGML_F32_ARR];
    GGML_F32_VEC ay[GGML_F32_ARR];

    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);

            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] += x[i]*v;
    }
#endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] += x[i]*v;
    }
#endif
}
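
// Usage sketch (illustrative only): ggml_vec_mad_f32 is an axpy-style update,
// y[i] += x[i]*v, e.g. accumulating a scaled row into an accumulator:
//
//     float acc[4] = { 0, 0, 0, 0 };
//     float row[4] = { 1, 2, 3, 4 };
//     ggml_vec_mad_f32(4, acc, row, 0.5f);   // acc = { 0.5f, 1.0f, 1.5f, 2.0f }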

inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);

    GGML_F16_VEC ax[GGML_F16_ARR];
    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);

            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
    }
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
    }
#endif
}

// xs and vs are byte strides of x and v
inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * GGML_RESTRICT y, const float * GGML_RESTRICT xv, const float * GGML_RESTRICT vv) {
    const float * GGML_RESTRICT x[GGML_VEC_MAD_UNROLL];
    const float * GGML_RESTRICT v[GGML_VEC_MAD_UNROLL];

    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
        x[i] = (const float *) ((const char *) xv + i*xs);
        v[i] = (const float *) ((const char *) vv + i*vs);
    }

#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
    // fall back to the scalar implementation; TODO: write SVE code
    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
        for (int i = 0; i < n; ++i) {
            y[i] += x[k][i]*v[k][0];
        }
    }
#else
    const int np = (n & ~(GGML_F32_STEP - 1));

    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];

    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
    }

    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
    GGML_F32_VEC ay[GGML_F32_ARR];

    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);

            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
            }

            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
        }
    }

    // leftovers
    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
        for (int i = np; i < n; ++i) {
            y[i] += x[k][i]*v[k][0];
        }
    }
#endif
#else
    // scalar
    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
        for (int i = 0; i < n; ++i) {
            y[i] += x[k][i]*v[k][0];
        }
    }
#endif
}

inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
#if defined(GGML_USE_ACCELERATE)
    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
#elif defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
    // scalar; TODO: write SVE code
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#else
    const int np = (n & ~(GGML_F32_STEP - 1));

    GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
    GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);

    GGML_F32_VEC ay[GGML_F32_ARR];

    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);

            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#endif
}
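
// Usage sketch (illustrative only): ggml_vec_mad1_f32 applies the affine map
// y[i] = x[i]*s + b (vDSP_vsmsa on Accelerate builds):
//
//     float in[3] = { 1, 2, 3 };
//     float out[3];
//     ggml_vec_mad1_f32(3, out, in, 2.0f, 1.0f);   // out = { 3, 5, 7 }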

//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(GGML_USE_ACCELERATE)
    vDSP_vsmul(y, 1, &v, y, 1, n);
#elif defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
    const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
    const int ggml_f32_epr  = sve_register_length / 32; // SVE128: 4, SVE256: 8, SVE512: 16
    const int ggml_f32_step = 2 * ggml_f32_epr;

    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

    const int np = (n & ~(ggml_f32_step - 1));

    svfloat32_t ay1;
    svfloat32_t ay2;

    for (int i = 0; i < np; i += ggml_f32_step) {
        ay1 = GGML_F32_VEC_LOAD(y + i);
        ay1 = GGML_F32_VEC_MUL(ay1, vx);
        GGML_F32_VEC_STORE(y + i, ay1);

        ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
        ay2 = GGML_F32_VEC_MUL(ay2, vx);
        GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
    }
    // leftovers
    // apply a predicated svmul to the elements that remain
    if (np < n) {
        svbool_t pg = svwhilelt_b32(np, n);
        ay1 = svld1_f32(pg, y + np);
        ay1 = svmul_f32_m(pg, ay1, vx);
        svst1_f32(pg, y + np, ay1);
    }
#else
    const int np = (n & ~(GGML_F32_STEP - 1));

    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

    GGML_F32_VEC ay[GGML_F32_ARR];

    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);

            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] *= v;
    }
#endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] *= v;
    }
#endif
}

inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);

    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);

            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
    }
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
    }
#endif
}

inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }

inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqr_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(v*v);
    }
}

inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
inline static void ggml_vec_sqrt_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(sqrtf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}

inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
inline static void ggml_vec_log_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(logf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}

inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); }
inline static void ggml_vec_sin_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(sinf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}

inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); }
inline static void ggml_vec_cos_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(cosf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}

inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
inline static void ggml_vec_abs_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(fabsf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}

inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
inline static void ggml_vec_sgn_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? 1.f : ((v < 0.f) ? -1.f : 0.f));
    }
}

inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
inline static void ggml_vec_step_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16((GGML_CPU_FP16_TO_FP32(x[i]) > 0.f) ? 1.f : 0.f);
    }
}

inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(tanhf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}

inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}

inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
inline static void ggml_vec_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : 0.f);
    }
}

inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const float ns) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? v : 0.f));
    }
}

inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(1.f / (1.f + expf(-GGML_CPU_FP16_TO_FP32(x[i]))));
    }
}

// TODO: optimize performance
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardswish_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(v * fminf(1.0f, fmaxf(0.0f, (v + 3.0f) / 6.0f)));
    }
}

inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardsigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(fminf(1.0f, fmaxf(0.0f, (GGML_CPU_FP16_TO_FP32(x[i]) + 3.0f) / 6.0f)));
    }
}

inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(expf(GGML_CPU_FP16_TO_FP32(x[i])));
    }
}

static const float GELU_COEF_A     = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
static const float SQRT_2_INV      = 0.70710678118654752440084436210484f;

inline static float ggml_gelu_f32(float x) {
    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    const uint16_t * i16 = (const uint16_t *) x;
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_table_gelu_f16[i16[i]];
    }
}

inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
        float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
        y[i] = GGML_CPU_FP32_TO_FP16(res);
    }
}

#ifdef GGML_GELU_FP16
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
    uint16_t t;
    for (int i = 0; i < n; ++i) {
        if (x[i] <= -10.0f) {
            y[i] = 0.0f;
        } else if (x[i] >= 10.0f) {
            y[i] = x[i];
        } else {
            ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
            memcpy(&t, &fp16, sizeof(uint16_t));
            y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]);
        }
    }
}
#else
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_gelu_f32(x[i]);
    }
}
#endif
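
// Note (illustrative only): with GGML_GELU_FP16 defined, inputs in (-10, 10)
// go through the fp16 lookup table, so results carry fp16 rounding; the
// branches above handle the saturated ranges exactly. For reference, the tanh
// approximation gives gelu(1.0f) of roughly 0.8412f:
//
//     float in  = 1.0f;
//     float out = 0.0f;
//     ggml_vec_gelu_f32(1, &out, &in);   // out ~ 0.84f (table-quantized)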

inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        float xi = x[i];
        y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
    }
}

inline static float ggml_gelu_quick_f32(float x) {
    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
}

//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
//    const uint16_t * i16 = (const uint16_t *) x;
//    for (int i = 0; i < n; ++i) {
//        y[i] = ggml_table_gelu_quick_f16[i16[i]];
//    }
//}

#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
    uint16_t t;
    for (int i = 0; i < n; ++i) {
        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
        memcpy(&t, &fp16, sizeof(uint16_t));
        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
    }
}
#else
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_gelu_quick_f32(x[i]);
    }
}
#endif

inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v))));
    }
}

// Sigmoid Linear Unit (SiLU) function
inline static float ggml_silu_f32(float x) {
    return x/(1.0f + expf(-x));
}

inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
    float v = GGML_CPU_FP16_TO_FP32(x);
    return GGML_CPU_FP32_TO_FP16(v/(1.0f + expf(-v)));
}

#if __FINITE_MATH_ONLY__
#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
#endif

/* Below function was borrowed from the GitHub repository:
https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */
#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
    // Constants
    const svfloat32_t log2_e      = svdup_n_f32(1.4426950409f);
    const svfloat32_t ln2         = svdup_n_f32(0.6931473921f);
    const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f);
    const svuint32_t  not_mask17  = svdup_n_u32(~((1u << 17) - 1));
    const svfloat32_t one         = svdup_n_f32(1.0f);
    const svfloat32_t inactive1   = svdup_n_f32(0.0f);
    const svint32_t   inactive2   = svdup_n_s32(0);

    // Algorithm starts here
    svfloat32_t t0 = svmul_f32_m(pg, src, log2_e);      // y = x * log2(e)
    svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0);  // round down to int (float)
    svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1);  // n
    t1 = svsub_f32_m(pg, t0, t1);                       // a = y - floor(y)
    t1 = svadd_f32_m(pg, t1, one);                      // b = a + 1

    svuint32_t  t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17);  // v = b >> 17 (u32)
    svfloat32_t t4 = svexpa_f32(t3);                                    // c = fexpa(v)
    t4 = svscale_f32_m(pg, t4, t2);                                     // fexpa(v) * 2^(n)

    // and_(t2.d, t1.d, not_mask17.d)
    svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17));
    t5 = svsub_f32_m(pg, t1, t5);                // z
    t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq);  // ln2 + half_ln2_sq * z
    t0 = svmla_f32_m(pg, one, t5, t0);           // 1 + (ln2 * z) + (half_ln2_sq * z * z)
    t0 = svmul_f32_m(pg, t0, t4);                // final result

    return t0;
}
#endif

#if defined(__ARM_NEON) && defined(__aarch64__)

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static float32x4_t ggml_v_expf(float32x4_t x) {
    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
    const float32x4_t n = vsubq_f32(z, r);
    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
                                    vdupq_n_f32(0x1.7f7d1cp-20f));
    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
    const float32x4_t u = vmulq_f32(b, b);
    const float32x4_t j = vfmaq_f32(
        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
        return vfmaq_f32(k, j, k);
    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}

// computes silu x/(1+exp(-x)) in single precision vector
inline static float32x4_t ggml_v_silu(float32x4_t x) {
    const float32x4_t one = vdupq_n_f32(1.0f);
    const float32x4_t zero = vdupq_n_f32(0.0f);
    const float32x4_t neg_x = vsubq_f32(zero, x);
    const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
    return vdivq_f32(x, one_plus_exp_neg_x);
}

#elif defined(__AVX512F__) && defined(__AVX512DQ__)

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static __m512 ggml_v_expf(__m512 x) {
    const __m512 r = _mm512_set1_ps(0x1.8p23f);
    const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
    const __m512 n = _mm512_sub_ps(z, r);
    const __m512 b =
        _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                         _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
    const __mmask16 d =
        _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
    const __m512 u = _mm512_mul_ps(b, b);
    const __m512 j = _mm512_fmadd_ps(
        _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                        _mm512_set1_ps(0x1.573e2ep-5f)),
                        u,
                        _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
                                        _mm512_set1_ps(0x1.fffdb6p-2f))),
        u,
        _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
    const __m512 res = _mm512_scalef_ps(j, n);
    if (_mm512_kortestz(d, d))
        return res;
    const __m512 zero = _mm512_setzero_ps();
    const __m512 alt = _mm512_mask_blend_ps(
        _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
    return _mm512_mask_blend_ps(d, res, alt);
}

// computes silu x/(1+exp(-x)) in single precision vector
inline static __m512 ggml_v_silu(__m512 x) {
    const __m512 one = _mm512_set1_ps(1);
    const __m512 zero = _mm512_setzero_ps();
    const __m512 neg_x = _mm512_sub_ps(zero, x);
    const __m512 exp_neg_x = ggml_v_expf(neg_x);
    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
    return _mm512_div_ps(x, one_plus_exp_neg_x);
}

#elif defined(__AVX2__) && defined(__FMA__)

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static __m256 ggml_v_expf(__m256 x) {
    const __m256 r = _mm256_set1_ps(0x1.8p23f);
    const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
    const __m256 n = _mm256_sub_ps(z, r);
    const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
                                      _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
    const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
    const __m256 k = _mm256_castsi256_ps(
        _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
    const __m256i c = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                      _mm256_set1_ps(126), _CMP_GT_OQ));
    const __m256 u = _mm256_mul_ps(b, b);
    const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
                                                                     _mm256_set1_ps(0x1.573e2ep-5f)), u,
                                                     _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
                                                                     _mm256_set1_ps(0x1.fffdb6p-2f))),
                                     u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
    if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
        return _mm256_fmadd_ps(j, k, k);
    const __m256i g = _mm256_and_si256(
        _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
        _mm256_set1_epi32(0x82000000u));
    const __m256 s1 =
        _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
    const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
    const __m256i d = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                      _mm256_set1_ps(192), _CMP_GT_OQ));
    return _mm256_or_ps(
        _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
        _mm256_andnot_ps(
            _mm256_castsi256_ps(d),
            _mm256_or_ps(
                _mm256_and_ps(_mm256_castsi256_ps(c),
                              _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
                _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}

// computes silu x/(1+exp(-x)) in single precision vector
inline static __m256 ggml_v_silu(__m256 x) {
    const __m256 one = _mm256_set1_ps(1);
    const __m256 zero = _mm256_setzero_ps();
    const __m256 neg_x = _mm256_sub_ps(zero, x);
    const __m256 exp_neg_x = ggml_v_expf(neg_x);
    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
    return _mm256_div_ps(x, one_plus_exp_neg_x);
}

#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON

#if defined(__FMA__)
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
#else
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
#endif

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static __m128 ggml_v_expf(__m128 x) {
    const __m128 r = _mm_set1_ps(0x1.8p23f);
    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
    const __m128 n = _mm_sub_ps(z, r);
    const __m128 b =
        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
    const __m128i c =
        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
    const __m128 u = _mm_mul_ps(b, b);
    const __m128 j =
        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
    if (!_mm_movemask_epi8(c))
        return MADD128(j, k, k);
    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
                                    _mm_set1_epi32(0x82000000u));
    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
    const __m128i d =
        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
    return _mm_or_ps(
        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
        _mm_andnot_ps(_mm_castsi128_ps(d),
                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
}

// computes silu x/(1+exp(-x)) in single precision vector
inline static __m128 ggml_v_silu(__m128 x) {
    const __m128 one = _mm_set1_ps(1);
    const __m128 zero = _mm_setzero_ps();
    const __m128 neg_x = _mm_sub_ps(zero, x);
    const __m128 exp_neg_x = ggml_v_expf(neg_x);
    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
    return _mm_div_ps(x, one_plus_exp_neg_x);
}

#endif // __ARM_NEON / __AVX2__ / __SSE2__
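
// Note (illustrative only): whichever SIMD path is compiled, ggml_v_silu
// agrees with the scalar definition silu(x) = x/(1 + exp(-x)) up to roughly
// the ~1.5 ulp error of ggml_v_expf documented above, e.g. via the vectorized
// wrapper declared at the top of this header:
//
//     float       sy[2];
//     const float sx[2] = { 0.0f, 1.0f };
//     ggml_vec_silu_f32(2, sy, sx);   // sy ~ { 0.0f, 0.7311f }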

inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_silu_f16(x[i]);
    }
}

inline static float ggml_silu_backward_f32(float x, float dy) {
    const float s = 1.0f/(1.0f + expf(-x));
    return dy*s*(1.0f + x*(1.0f - s));
}

inline static ggml_fp16_t ggml_silu_backward_f16(ggml_fp16_t x, ggml_fp16_t dy) {
    const float v = GGML_CPU_FP16_TO_FP32(x);
    const float s = 1.0f/(1.0f + expf(-v));
    return GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(dy)*s*(1.0f + v*(1.0f - s)));
}

inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
    for (int i = 0; i < n; ++i) {
        dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
    }
}

inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) {
    for (int i = 0; i < n; ++i) {
        dx[i] = ggml_silu_backward_f16(x[i], dy[i]);
    }
}

inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
    }
}

inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(x[i]);
        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f);
    }
}

#ifdef GGML_GELU_FP16
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
    uint16_t t;
    for (int i = 0; i < n; ++i) {
        if (x[i] <= -10.0f) {
            y[i] = 0.0f;
        } else if (x[i] >= 10.0f) {
            y[i] = x[i] * g[i];
        } else {
            ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
            memcpy(&t, &fp16, sizeof(uint16_t));
            y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
        }
    }
}
#else
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_gelu_f32(x[i]) * g[i];
    }
}
#endif

inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    const uint16_t * i16 = (const uint16_t *) x;
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(g[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
    }
}

void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);

inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    for (int i = 0; i < n; ++i) {
        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
        y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
    }
}

inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        float xi = x[i];
        y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
    }
}

inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    for (int i = 0; i < n; ++i) {
        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
    }
}

#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    uint16_t t;
    for (int i = 0; i < n; ++i) {
        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
        memcpy(&t, &fp16, sizeof(uint16_t));
        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
    }
}
#else
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
    }
}
#endif

inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
    const uint16_t * i16 = (const uint16_t *) x;
    for (int i = 0; i < n; ++i) {
        float v = GGML_CPU_FP16_TO_FP32(g[i]);
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
    }
}

inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
    ggml_float sum = 0.0;
    for (int i = 0; i < n; ++i) {
        sum += (ggml_float)x[i];
    }
    *s = (float)sum;
#else
    vDSP_sve(x, 1, s, n);
#endif
}
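
// Note (illustrative only): the non-Accelerate path accumulates in ggml_float
// (double, see the typedef at the top) before narrowing to float, which keeps
// long reductions from dropping low-order bits:
//
//     float ones[1000];
//     ggml_vec_set_f32(1000, ones, 1.0f);
//     float total;
//     ggml_vec_sum_f32(1000, &total, ones);   // total == 1000.0f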

inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
    ggml_float sum = 0.0;
    for (int i = 0; i < n; ++i) {
        sum += (ggml_float)x[i];
    }
    *s = sum;
}

inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += GGML_CPU_FP16_TO_FP32(x[i]);
    }
    *s = sum;
}

inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += GGML_BF16_TO_FP32(x[i]);
    }
    *s = sum;
}

inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {
        max = MAX(max, x[i]);
    }
    *s = max;
#else
    vDSP_maxv(x, 1, s, n);
#endif
}

inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
    ggml_vec_norm_f32(n, s, x);
    *s = 1.f/(*s);
}

inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
    float max = -INFINITY;
    int idx = 0;
    for (int i = 0; i < n; ++i) {
        max = MAX(max, x[i]);
        if (max == x[i]) { idx = i; }
    }
    *s = idx;
}

#ifdef __cplusplus
}
#endif