// vec.h
  1. // Vectorized functions for fundamental operations
  2. #pragma once
  3. #include "ggml-impl.h"
  4. #include "simd-mappings.h"
  5. #include "ggml.h"
  6. #include "ggml-cpu.h"
  7. #if defined(GGML_USE_ACCELERATE)
  8. #include <Accelerate/Accelerate.h>
  9. #endif
  10. // floating point type used to accumulate sums
  11. typedef double ggml_float;
  12. #define GGML_GELU_FP16
  13. #define GGML_GELU_QUICK_FP16
  14. #define GGML_SOFT_MAX_UNROLL 4
  15. #define GGML_VEC_DOT_UNROLL 2
  16. #define GGML_VEC_MAD_UNROLL 32
  17. #ifdef __cplusplus
  18. extern "C" {
  19. #endif
  20. //
  21. // global data
  22. //
  23. // precomputed gelu table for f16 (128 KB)
  24. extern ggml_fp16_t ggml_table_gelu_f16[1 << 16];
  25. // precomputed quick gelu table for f16 (128 KB)
  26. extern ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
  27. //
  28. // fundamental operations
  29. //
  30. void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
  31. void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
  32. void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
  33. void ggml_vec_silu_f32(const int n, float * y, const float * x);
  34. ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
  35. ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
  // Fill x[0..n-1] with the value v.
  inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) {
      int k = 0;
      while (k < n) {
          x[k] = v;
          ++k;
      }
  }
  // Fill x[0..n-1] with the value v.
  inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) {
      int k = 0;
      while (k < n) {
          x[k] = v;
          ++k;
      }
  }
  // Fill x[0..n-1] with the value v.
  inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
      int k = 0;
      while (k < n) {
          x[k] = v;
          ++k;
      }
  }
  // Copy n 32-bit integers from x to y, front to back.
  inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) {
      for (int k = 0; k < n; k++) {
          const int32_t value = x[k];
          y[k] = value;
      }
  }
  40. inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
  41. inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
  // z[i] = x[i] + y[i] for i in [0, n).
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
      for (int k = 0; k < n; k++) {
          const float sum = x[k] + y[k];
          z[k] = sum;
      }
  }
  43. inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
  44. for (int i = 0; i < n; ++i) {
  45. z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
  46. }
  47. }
  // z[i] = x[i] + v for i in [0, n).
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) {
      for (int k = 0; k < n; k++) {
          const float shifted = x[k] + v;
          z[k] = shifted;
      }
  }
  // Accumulate in place: y[i] += x[i] for i in [0, n).
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) {
      int k = 0;
      while (k < n) {
          y[k] += x[k];
          ++k;
      }
  }
  // Add the scalar v to every element in place: y[i] += v for i in [0, n).
  inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) {
      int k = 0;
      while (k < n) {
          y[k] += v;
          ++k;
      }
  }
  // z[i] = x[i] - y[i] for i in [0, n).
  inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) {
      for (int k = 0; k < n; k++) {
          const float diff = x[k] - y[k];
          z[k] = diff;
      }
  }
  52. inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
  53. for (int i = 0; i < n; ++i) {
  54. z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) - GGML_CPU_FP16_TO_FP32(y[i]));
  55. }
  56. }
  // Fill x[0..n-1] with the value v.
  inline static void ggml_vec_set_f32 (const int n, float * x, const float v) {
      int k = 0;
      while (k < n) {
          x[k] = v;
          ++k;
      }
  }
  // Copy n floats from x to y, front to back.
  inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) {
      for (int k = 0; k < n; k++) {
          const float value = x[k];
          y[k] = value;
      }
  }
  // y[i] = -x[i] for i in [0, n).
  inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) {
      for (int k = 0; k < n; k++) {
          const float negated = -x[k];
          y[k] = negated;
      }
  }
  60. inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  61. for (int i = 0; i < n; ++i) {
  62. y[i] = GGML_CPU_FP32_TO_FP16(-GGML_CPU_FP16_TO_FP32(x[i]));
  63. }
  64. }
  // z[i] = x[i] * y[i] for i in [0, n).
  inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) {
      for (int k = 0; k < n; k++) {
          const float product = x[k]*y[k];
          z[k] = product;
      }
  }
  66. inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
  67. for (int i = 0; i < n; ++i) {
  68. z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) * GGML_CPU_FP16_TO_FP32(y[i]));
  69. }
  70. }
  // z[i] = x[i] / y[i] for i in [0, n); no check is made for zero divisors.
  inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) {
      for (int k = 0; k < n; k++) {
          const float quotient = x[k]/y[k];
          z[k] = quotient;
      }
  }
  72. inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
  73. for (int i = 0; i < n; ++i) {
  74. z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) / GGML_CPU_FP16_TO_FP32(y[i]));
  75. }
  76. }
  77. // compute GGML_VEC_DOT_UNROLL dot products at once
  78. // xs - x row stride in bytes
  79. inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GGML_RESTRICT s, void * GGML_RESTRICT xv, ggml_fp16_t * GGML_RESTRICT y) {
  80. ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
  81. ggml_fp16_t * GGML_RESTRICT x[GGML_VEC_DOT_UNROLL];
  82. for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
  83. x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
  84. }
  85. #if defined(GGML_SIMD)
  86. const int np = (n & ~(GGML_F16_STEP - 1));
  87. GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
  88. GGML_F16_VEC ax[GGML_F16_ARR];
  89. GGML_F16_VEC ay[GGML_F16_ARR];
  90. for (int i = 0; i < np; i += GGML_F16_STEP) {
  91. for (int j = 0; j < GGML_F16_ARR; j++) {
  92. ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
  93. for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
  94. ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
  95. sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
  96. }
  97. }
  98. }
  99. // reduce sum0..sum3 to sum0
  100. for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
  101. GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
  102. }
  103. // leftovers
  104. for (int i = np; i < n; ++i) {
  105. for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
  106. sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
  107. }
  108. }
  109. #else
  110. for (int i = 0; i < n; ++i) {
  111. for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
  112. sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
  113. }
  114. }
  115. #endif
  116. for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
  117. s[i] = (float)sumf[i];
  118. }
  119. }
  120. inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) {
  121. #if defined(GGML_SIMD)
  122. #if defined(__ARM_FEATURE_SVE)
  123. const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
  124. const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16
  125. const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers
  126. GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
  127. const int np = (n & ~(ggml_f32_step - 1));
  128. svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
  129. svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
  130. for (int i = 0; i < np; i += ggml_f32_step) {
  131. ax1 = GGML_F32_VEC_LOAD(x + i);
  132. ay1 = GGML_F32_VEC_LOAD(y + i);
  133. ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
  134. GGML_F32_VEC_STORE(y + i, ay1);
  135. ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
  136. ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
  137. ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
  138. GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
  139. ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
  140. ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
  141. ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
  142. GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
  143. ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
  144. ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
  145. ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
  146. GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
  147. ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
  148. ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
  149. ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
  150. GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
  151. ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
  152. ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
  153. ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
  154. GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
  155. ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
  156. ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
  157. ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
  158. GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
  159. ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
  160. ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
  161. ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
  162. GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
  163. }
  164. // leftovers
  165. // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
  166. const int np2 = (n & ~(ggml_f32_epr - 1));
  167. for (int i = np; i < np2; i += ggml_f32_epr) {
  168. ax1 = GGML_F32_VEC_LOAD(x + i);
  169. ay1 = GGML_F32_VEC_LOAD(y + i);
  170. ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
  171. GGML_F32_VEC_STORE(y + i, ay1);
  172. }
  173. // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
  174. if (np2 < n) {
  175. svbool_t pg =svwhilelt_b32(np2, n);
  176. ax1 = svld1_f32(pg, x + np2);
  177. ay1 = svld1_f32(pg, y + np2);
  178. ay1 = svmad_f32_m(pg, ax1, vx, ay1);
  179. svst1_f32(pg, y + np2, ay1);
  180. }
  181. #else
  182. const int np = (n & ~(GGML_F32_STEP - 1));
  183. GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
  184. GGML_F32_VEC ax[GGML_F32_ARR];
  185. GGML_F32_VEC ay[GGML_F32_ARR];
  186. for (int i = 0; i < np; i += GGML_F32_STEP) {
  187. for (int j = 0; j < GGML_F32_ARR; j++) {
  188. ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
  189. ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
  190. ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
  191. GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
  192. }
  193. }
  194. // leftovers
  195. for (int i = np; i < n; ++i) {
  196. y[i] += x[i]*v;
  197. }
  198. #endif
  199. #else
  200. // scalar
  201. for (int i = 0; i < n; ++i) {
  202. y[i] += x[i]*v;
  203. }
  204. #endif
  205. }
  206. inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
  207. #if defined(GGML_SIMD)
  208. const int np = (n & ~(GGML_F16_STEP - 1));
  209. GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
  210. GGML_F16_VEC ax[GGML_F16_ARR];
  211. GGML_F16_VEC ay[GGML_F16_ARR];
  212. for (int i = 0; i < np; i += GGML_F16_STEP) {
  213. for (int j = 0; j < GGML_F16_ARR; j++) {
  214. ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
  215. ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
  216. ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
  217. GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
  218. }
  219. }
  220. // leftovers
  221. for (int i = np; i < n; ++i) {
  222. y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
  223. }
  224. #else
  225. // scalar
  226. for (int i = 0; i < n; ++i) {
  227. y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
  228. }
  229. #endif
  230. }
  231. // xs and vs are byte strides of x and v
  232. inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * GGML_RESTRICT y, const float * GGML_RESTRICT xv, const float * GGML_RESTRICT vv) {
  233. const float * GGML_RESTRICT x[GGML_VEC_MAD_UNROLL];
  234. const float * GGML_RESTRICT v[GGML_VEC_MAD_UNROLL];
  235. for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
  236. x[i] = (const float *) ((const char *) xv + i*xs);
  237. v[i] = (const float *) ((const char *) vv + i*vs);
  238. }
  239. #if defined(GGML_SIMD)
  240. #if defined(__ARM_FEATURE_SVE)
  241. // scalar Route to scalar implementation //TODO: Write SVE code
  242. for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
  243. for (int i = 0; i < n; ++i) {
  244. y[i] += x[k][i]*v[k][0];
  245. }
  246. }
  247. #else
  248. const int np = (n & ~(GGML_F32_STEP - 1));
  249. GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
  250. for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
  251. vx[k] = GGML_F32_VEC_SET1(v[k][0]);
  252. }
  253. GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
  254. GGML_F32_VEC ay[GGML_F32_ARR];
  255. for (int i = 0; i < np; i += GGML_F32_STEP) {
  256. for (int j = 0; j < GGML_F32_ARR; j++) {
  257. ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
  258. for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
  259. ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
  260. ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
  261. }
  262. GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
  263. }
  264. }
  265. // leftovers
  266. for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
  267. for (int i = np; i < n; ++i) {
  268. y[i] += x[k][i]*v[k][0];
  269. }
  270. }
  271. #endif
  272. #else
  273. // scalar
  274. for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
  275. for (int i = 0; i < n; ++i) {
  276. y[i] += x[k][i]*v[k][0];
  277. }
  278. }
  279. #endif
  280. }
  // Scale and bias: y[i] = x[i]*s + b for i in [0, n).
  //   n - number of elements
  //   y - destination
  //   x - source
  //   s - scale applied to every element
  //   b - bias added after scaling
  inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
  #if defined(GGML_USE_ACCELERATE)
      vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
  #elif defined(GGML_SIMD)
      #if defined(__ARM_FEATURE_SVE)
          // scalar ; TODO: Write SVE code
          for (int i = 0; i < n; ++i) {
              y[i] = x[i]*s + b;
          }
      #else
          const int np = (n & ~(GGML_F32_STEP - 1));

          GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
          GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);

          GGML_F32_VEC ay[GGML_F32_ARR];

          for (int i = 0; i < np; i += GGML_F32_STEP) {
              for (int j = 0; j < GGML_F32_ARR; j++) {
                  ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                  // GGML_F32_VEC_FMA(a, b, c) evaluates a + b*c (ggml_vec_mad_f32 uses it as y + x*v),
                  // so the bias has to be the first operand to obtain x*s + b
                  ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);

                  GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
              }
          }

          // leftovers
          for (int i = np; i < n; ++i) {
              y[i] = x[i]*s + b;
          }
      #endif
  #else
      // scalar
      for (int i = 0; i < n; ++i) {
          y[i] = x[i]*s + b;
      }
  #endif
  }
  // Scale in place: y[i] *= v for i in [0, n).
  //   n - number of elements
  //   y - vector that is scaled in place
  //   v - scale factor
  inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
  #if defined(GGML_USE_ACCELERATE)
      vDSP_vsmul(y, 1, &v, y, 1, n);
  #elif defined(GGML_SIMD)
      #if defined(__ARM_FEATURE_SVE)
          const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
          const int ggml_f32_epr  = sve_register_length / 32; // f32 lanes per register (SVE128:4, SVE256:8, SVE512:16)
          const int ggml_f32_step = 2 * ggml_f32_epr;          // two registers per iteration

          GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
          const int np = (n & ~(ggml_f32_step - 1));
          svfloat32_t ay1;
          svfloat32_t ay2;
          for (int i = 0; i < np; i += ggml_f32_step) {
              ay1 = GGML_F32_VEC_LOAD(y + i);
              ay1 = GGML_F32_VEC_MUL(ay1, vx);
              GGML_F32_VEC_STORE(y + i, ay1);

              ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
              ay2 = GGML_F32_VEC_MUL(ay2, vx);
              GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
          }
          // leftovers
          // Up to ggml_f32_step - 1 (= 2*ggml_f32_epr - 1) elements remain, which can exceed one
          // register, so step through them with a predicated multiply per register-sized chunk.
          for (int i = np; i < n; i += ggml_f32_epr) {
              svbool_t pg = svwhilelt_b32(i, n);
              ay1 = svld1_f32(pg, y + i);
              ay1 = svmul_f32_m(pg, ay1, vx);
              svst1_f32(pg, y + i, ay1);
          }
      #else
          const int np = (n & ~(GGML_F32_STEP - 1));

          GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

          GGML_F32_VEC ay[GGML_F32_ARR];

          for (int i = 0; i < np; i += GGML_F32_STEP) {
              for (int j = 0; j < GGML_F32_ARR; j++) {
                  ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
                  ay[j] = GGML_F32_VEC_MUL(ay[j], vx);

                  GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
              }
          }

          // leftovers
          for (int i = np; i < n; ++i) {
              y[i] *= v;
          }
      #endif
  #else
      // scalar
      for (int i = 0; i < n; ++i) {
          y[i] *= v;
      }
  #endif
  }
  366. inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
  367. #if defined(GGML_SIMD)
  368. const int np = (n & ~(GGML_F16_STEP - 1));
  369. GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
  370. GGML_F16_VEC ay[GGML_F16_ARR];
  371. for (int i = 0; i < np; i += GGML_F16_STEP) {
  372. for (int j = 0; j < GGML_F16_ARR; j++) {
  373. ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
  374. ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
  375. GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
  376. }
  377. }
  378. // leftovers
  379. for (int i = np; i < n; ++i) {
  380. y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
  381. }
  382. #else
  383. // scalar
  384. for (int i = 0; i < n; ++i) {
  385. y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
  386. }
  387. #endif
  388. }
  // Store in *s the square root of the value ggml_vec_dot_f32 produces for x with itself.
  inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) {
      ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1);
      const float dot_xx = *s;
      *s = sqrtf(dot_xx);
  }
  // y[i] = x[i]^2 for i in [0, n).
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) {
      for (int k = 0; k < n; k++) {
          const float squared = x[k]*x[k];
          y[k] = squared;
      }
  }
  391. inline static void ggml_vec_sqr_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  392. for (int i = 0; i < n; ++i) {
  393. float v = GGML_CPU_FP16_TO_FP32(x[i]);
  394. y[i] = GGML_CPU_FP32_TO_FP16(v*v);
  395. }
  396. }
  397. inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
  398. inline static void ggml_vec_sqrt_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  399. for (int i = 0; i < n; ++i) {
  400. y[i] = GGML_CPU_FP32_TO_FP16(sqrtf(GGML_CPU_FP16_TO_FP32(x[i])));
  401. }
  402. }
  403. inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
  404. inline static void ggml_vec_log_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  405. for (int i = 0; i < n; ++i) {
  406. y[i] = GGML_CPU_FP32_TO_FP16(logf(GGML_CPU_FP16_TO_FP32(x[i])));
  407. }
  408. }
  409. inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); }
  410. inline static void ggml_vec_sin_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  411. for (int i = 0; i < n; ++i) {
  412. y[i] = GGML_CPU_FP32_TO_FP16(sinf(GGML_CPU_FP16_TO_FP32(x[i])));
  413. }
  414. }
  415. inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); }
  416. inline static void ggml_vec_cos_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  417. for (int i = 0; i < n; ++i) {
  418. y[i] = GGML_CPU_FP32_TO_FP16(cosf(GGML_CPU_FP16_TO_FP32(x[i])));
  419. }
  420. }
  // y[i] = |x[i]| for i in [0, n).
  inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) {
      for (int k = 0; k < n; k++) {
          const float magnitude = fabsf(x[k]);
          y[k] = magnitude;
      }
  }
  422. inline static void ggml_vec_abs_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  423. for (int i = 0; i < n; ++i) {
  424. y[i] = GGML_CPU_FP32_TO_FP16(fabsf(GGML_CPU_FP16_TO_FP32(x[i])));
  425. }
  426. }
  // Sign function: y[i] is 1 for positive x[i], -1 for negative x[i] and 0 otherwise.
  inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) {
      for (int k = 0; k < n; k++) {
          const float value = x[k];
          float sign;
          if (value > 0.f) {
              sign = 1.f;
          } else if (value < 0.f) {
              sign = -1.f;
          } else {
              sign = 0.f;
          }
          y[k] = sign;
      }
  }
  428. inline static void ggml_vec_sgn_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  429. for (int i = 0; i < n; ++i) {
  430. float v = GGML_CPU_FP16_TO_FP32(x[i]);
  431. y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? 1.f : ((v < 0.f) ? -1.f : 0.f));
  432. }
  433. }
  // Step function: y[i] is 1 where x[i] > 0 and 0 elsewhere.
  inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
      for (int k = 0; k < n; k++) {
          if (x[k] > 0.f) {
              y[k] = 1.f;
          } else {
              y[k] = 0.f;
          }
      }
  }
  435. inline static void ggml_vec_step_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  436. for (int i = 0; i < n; ++i) {
  437. y[i] = GGML_CPU_FP32_TO_FP16((GGML_CPU_FP16_TO_FP32(x[i]) > 0.f) ? 1.f : 0.f);
  438. }
  439. }
  440. inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
  441. inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  442. for (int i = 0; i < n; ++i) {
  443. y[i] = GGML_CPU_FP32_TO_FP16(tanhf(GGML_CPU_FP16_TO_FP32(x[i])));
  444. }
  445. }
  446. inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
  447. inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  448. for (int i = 0; i < n; ++i) {
  449. y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
  450. }
  451. }
  // ReLU: y[i] = x[i] where x[i] > 0, 0 elsewhere.
  inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) {
      for (int k = 0; k < n; k++) {
          const float value = x[k];
          if (value > 0.f) {
              y[k] = value;
          } else {
              y[k] = 0.f;
          }
      }
  }
  453. inline static void ggml_vec_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  454. for (int i = 0; i < n; ++i) {
  455. float v = GGML_CPU_FP16_TO_FP32(x[i]);
  456. y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : 0.f);
  457. }
  458. }
  // Leaky ReLU: y[i] = max-part(x[i]) + ns * min-part(x[i]), where the positive part is
  // x[i] if x[i] > 0 (else 0) and the negative part is x[i] if x[i] < 0 (else 0).
  inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) {
      for (int k = 0; k < n; k++) {
          const float value    = x[k];
          const float positive = (value > 0.f)  ? value : 0.f;
          const float negative = (value < 0.0f) ? value : 0.f;
          y[k] = positive + ns * negative;
      }
  }
  460. inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const float ns) {
  461. for (int i = 0; i < n; ++i) {
  462. float v = GGML_CPU_FP16_TO_FP32(x[i]);
  463. y[i] = GGML_CPU_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? v : 0.f));
  464. }
  465. }
  466. inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
  467. inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  468. for (int i = 0; i < n; ++i) {
  469. y[i] = GGML_CPU_FP32_TO_FP16(1.f / (1.f + expf(-GGML_CPU_FP16_TO_FP32(x[i]))));
  470. }
  471. }
  472. // TODO: optimize performance
  473. inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
  474. inline static void ggml_vec_hardswish_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  475. for (int i = 0; i < n; ++i) {
  476. float v = GGML_CPU_FP16_TO_FP32(x[i]);
  477. y[i] = GGML_CPU_FP32_TO_FP16(v * fminf(1.0f, fmaxf(0.0f, (v + 3.0f) / 6.0f)));
  478. }
  479. }
  480. inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
  481. inline static void ggml_vec_hardsigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  482. for (int i = 0; i < n; ++i) {
  483. y[i] = GGML_CPU_FP32_TO_FP16(fminf(1.0f, fmaxf(0.0f, (GGML_CPU_FP16_TO_FP32(x[i]) + 3.0f) / 6.0f)));
  484. }
  485. }
  486. inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
  487. inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  488. for (int i = 0; i < n; ++i) {
  489. y[i] = GGML_CPU_FP32_TO_FP16(expf(GGML_CPU_FP16_TO_FP32(x[i])));
  490. }
  491. }
  // Constants used by the GELU variants in this header.
  static const float GELU_COEF_A     = 0.044715f;                          // cubic-term coefficient inside the tanh form
  static const float GELU_QUICK_COEF = -1.702f;                            // exponent scale of the "quick" (sigmoid) form
  static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f; // sqrt(2/pi)
  static const float SQRT_2_INV      = 0.70710678118654752440084436210484f; // 1/sqrt(2)
  496. inline static float ggml_gelu_f32(float x) {
  497. return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
  498. }
  499. inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  500. const uint16_t * i16 = (const uint16_t *) x;
  501. for (int i = 0; i < n; ++i) {
  502. y[i] = ggml_table_gelu_f16[i16[i]];
  503. }
  504. }
  505. inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  506. for (int i = 0; i < n; ++i) {
  507. float xi = GGML_CPU_FP16_TO_FP32(x[i]);
  508. float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
  509. y[i] = GGML_CPU_FP32_TO_FP16(res);
  510. }
  511. }
  #ifdef GGML_GELU_FP16
  // GELU over n floats using the f16 lookup table.
  // Inputs <= -10 map to 0 and inputs >= 10 map to themselves; everything else is
  // rounded to fp16 and its bit pattern is used as the index into ggml_table_gelu_f16.
  inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
      for (int k = 0; k < n; k++) {
          const float value = x[k];
          if (value <= -10.0f) {
              y[k] = 0.0f;
              continue;
          }
          if (value >= 10.0f) {
              y[k] = value;
              continue;
          }
          const ggml_fp16_t half = GGML_CPU_FP32_TO_FP16(value);
          uint16_t bits;
          memcpy(&bits, &half, sizeof(uint16_t));
          y[k] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[bits]);
      }
  }
  #else
  // GELU over n floats, evaluated directly with ggml_gelu_f32.
  inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
      int k = 0;
      while (k < n) {
          y[k] = ggml_gelu_f32(x[k]);
          ++k;
      }
  }
  #endif
  534. inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
  535. for (int i = 0; i < n; ++i) {
  536. float xi = x[i];
  537. y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
  538. }
  539. }
  540. inline static float ggml_gelu_quick_f32(float x) {
  541. return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
  542. }
  543. //inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  544. // const uint16_t * i16 = (const uint16_t *) x;
  545. // for (int i = 0; i < n; ++i) {
  546. // y[i] = ggml_table_gelu_quick_f16[i16[i]];
  547. // }
  548. //}
  #ifdef GGML_GELU_QUICK_FP16
  // Quick GELU over n floats using the f16 lookup table: every input is rounded to fp16
  // and its bit pattern is used as the index into ggml_table_gelu_quick_f16.
  inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
      for (int k = 0; k < n; k++) {
          const ggml_fp16_t half = GGML_CPU_FP32_TO_FP16(x[k]);
          uint16_t bits;
          memcpy(&bits, &half, sizeof(uint16_t));
          y[k] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[bits]);
      }
  }
  #else
  // Quick GELU over n floats, evaluated directly with ggml_gelu_quick_f32.
  inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
      int k = 0;
      while (k < n) {
          y[k] = ggml_gelu_quick_f32(x[k]);
          ++k;
      }
  }
  #endif
  565. inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  566. for (int i = 0; i < n; ++i) {
  567. float v = GGML_CPU_FP16_TO_FP32(x[i]);
  568. y[i] = GGML_CPU_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v))));
  569. }
  570. }
  571. // Sigmoid Linear Unit (SiLU) function
  572. inline static float ggml_silu_f32(float x) {
  573. return x/(1.0f + expf(-x));
  574. }
  575. inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
  576. float v = GGML_CPU_FP16_TO_FP32(x);
  577. return GGML_CPU_FP32_TO_FP16(v/(1.0f + expf(-v)));
  578. }
  579. #if __FINITE_MATH_ONLY__
  580. #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
  581. #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
  582. #endif
  /* Below function was borrowed from the GitHub repository:
  https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */
  #if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
  // Vector exp() of src on the lanes selected by pg.
  // The input is scaled by log2(e) and split into floor and fraction; svexpa/svscale
  // provide a coarse 2^y and a second-order polynomial in the residual z refines it.
  inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
      // Constants
      const svfloat32_t log2_e      = svdup_n_f32(1.4426950409f);
      const svfloat32_t ln2         = svdup_n_f32(0.6931473921f);
      const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f);
      const svuint32_t  not_mask17  = svdup_n_u32(~((1u << 17) - 1));
      const svfloat32_t one         = svdup_n_f32(1.0f);
      const svfloat32_t inactive1   = svdup_n_f32(0.0f);
      const svint32_t   inactive2   = svdup_n_s32(0);

      // y = x * log2(e), split into floor(y) and the fraction
      const svfloat32_t scaled   = svmul_f32_m(pg, src, log2_e);
      const svfloat32_t floored  = svrintm_f32_m(inactive1, pg, scaled);    // round to int (float)
      const svint32_t   exponent = svcvt_s32_f32_m(inactive2, pg, floored); // n

      // b = (y - floor(y)) + 1
      const svfloat32_t frac1      = svadd_f32_m(pg, svsub_f32_m(pg, scaled, floored), one);
      const svuint32_t  frac1_bits = svreinterpret_u32_f32(frac1);

      // coarse value: fexpa(b >> 17) * 2^n
      const svuint32_t  index  = svlsr_n_u32_m(pg, frac1_bits, 17);
      const svfloat32_t coarse = svscale_f32_m(pg, svexpa_f32(index), exponent);

      // z = b with its low 17 bits isolated (b minus b masked by not_mask17)
      const svfloat32_t masked = svreinterpret_f32_u32(svand_u32_m(pg, frac1_bits, not_mask17));
      const svfloat32_t z      = svsub_f32_m(pg, frac1, masked);

      // 1 + (ln2 * z) + (half_ln2_sq * z * z)
      svfloat32_t poly = svmla_f32_m(pg, ln2, z, half_ln2_sq);
      poly = svmla_f32_m(pg, one, z, poly);

      // Final result
      return svmul_f32_m(pg, poly, coarse);
  }
  #endif
  613. #if defined(__ARM_NEON) && defined(__aarch64__)
// Vectorized expf for four floats (NEON).
// Adapted from Arm Limited's optimized routines.
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
//
// Outline: n = round(x*log2(e)), b = x - n*ln2, result = 2^n * (1 + j) where j is a
// polynomial in b. Lanes with |n| > 126 go through a slower path that splits the scaling
// into two multiplications.
inline static float32x4_t ggml_v_expf(float32x4_t x) {
    // r = 1.5 * 2^23: adding it leaves the rounded integer in the low mantissa bits of z
    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
    // z = r + x*log2(e)   (0x1.715476p+0 ~= 1.442695)
    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
    // n = x*log2(e) rounded to an integer, as a float
    const float32x4_t n = vsubq_f32(z, r);
    // b = x - n*ln2, with ln2 applied as a high part (0x1.62e4p-1) and a low part (0x1.7f7d1cp-20)
    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
                                    vdupq_n_f32(0x1.7f7d1cp-20f));
    // move the integer held in z's low bits into the float exponent field
    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
    // adding the bit pattern of 1.0f gives k = 2^n
    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
    // c: lanes where |n| > 126
    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
    const float32x4_t u = vmulq_f32(b, b);
    // j = C1*b + ((C2 + C3*b) + (C4 + C5*b)*u)*u, with u = b*b
    const float32x4_t j = vfmaq_f32(
        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
    // fast path: no lane has |n| > 126, so return k + j*k
    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
        return vfmaq_f32(k, j, k);
    // d = 0x82000000 in lanes where n <= 0, otherwise 0
    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
    // s1 has bit pattern 0x7f000000 + d (2^127 for n > 0; wraps to 0x01000000 = 2^-125 for n <= 0)
    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
    // s2 has bit pattern e - d
    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
    // per lane: |n| > 192 -> s1*s1; else |n| > 126 -> (s2 + s2*j)*s1; else k + k*j
    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
  640. // computes silu x/(1+exp(-x)) in single precision vector
  641. inline static float32x4_t ggml_v_silu(float32x4_t x) {
  642. const float32x4_t one = vdupq_n_f32(1.0f);
  643. const float32x4_t zero = vdupq_n_f32(0.0f);
  644. const float32x4_t neg_x = vsubq_f32(zero, x);
  645. const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
  646. const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
  647. return vdivq_f32(x, one_plus_exp_neg_x);
  648. }
  649. #elif defined(__AVX512F__) && defined(__AVX512DQ__)
// Vectorized expf for sixteen floats (AVX-512).
// Adapted from Arm Limited's optimized routines.
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
//
// Outline: n = round(x*log2(e)), b = x - n*ln2, j = polynomial in b (constant term 1),
// result = j * 2^n via _mm512_scalef_ps. Lanes with |n| > 192 are replaced by 0 or +inf.
inline static __m512 ggml_v_expf(__m512 x) {
    // r = 1.5 * 2^23: adding it leaves the rounded integer in the low mantissa bits of z
    const __m512 r = _mm512_set1_ps(0x1.8p23f);
    // z = x*log2(e) + r   (0x1.715476p+0 ~= 1.442695)
    const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
    // n = x*log2(e) rounded to an integer, as a float
    const __m512 n = _mm512_sub_ps(z, r);
    // b = x - n*ln2, with ln2 applied as a high part (0x1.62e4p-1) and a low part (0x1.7f7d1cp-20)
    const __m512 b =
        _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                         _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
    // d: mask of lanes where |n| > 192
    const __mmask16 d =
        _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
    const __m512 u = _mm512_mul_ps(b, b);
    // j = ((C5*b + C4)*u + (C3*b + C2))*u + (C1*b + 1), with u = b*b
    const __m512 j = _mm512_fmadd_ps(
        _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                        _mm512_set1_ps(0x1.573e2ep-5f)),
                        u,
                        _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
                                        _mm512_set1_ps(0x1.fffdb6p-2f))),
        u,
        _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
    // res = j * 2^n
    const __m512 res = _mm512_scalef_ps(j, n);
    // fast path: the mask d is empty, so every lane of res is used as-is
    if (_mm512_kortestz(d, d))
        return res;
    const __m512 zero = _mm512_setzero_ps();
    // alt = 0 in lanes where n <= 0, +inf otherwise
    const __m512 alt = _mm512_mask_blend_ps(
        _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
    // take alt in the lanes flagged by d, res elsewhere
    return _mm512_mask_blend_ps(d, res, alt);
}
  680. // computes silu x/(1+exp(-x)) in single precision vector
  681. inline static __m512 ggml_v_silu(__m512 x) {
  682. const __m512 one = _mm512_set1_ps(1);
  683. const __m512 zero = _mm512_setzero_ps();
  684. const __m512 neg_x = _mm512_sub_ps(zero, x);
  685. const __m512 exp_neg_x = ggml_v_expf(neg_x);
  686. const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
  687. return _mm512_div_ps(x, one_plus_exp_neg_x);
  688. }
  689. #elif defined(__AVX2__) && defined(__FMA__)
// Vectorized expf for eight floats (AVX2 + FMA).
// Adapted from Arm Limited's optimized routines.
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
//
// Outline: n = round(x*log2(e)), b = x - n*ln2, result = 2^n * (1 + j) where j is a
// polynomial in b. Lanes with |n| > 126 go through a slower path that splits the scaling
// into two multiplications.
inline static __m256 ggml_v_expf(__m256 x) {
    // r = 1.5 * 2^23: adding it leaves the rounded integer in the low mantissa bits of z
    const __m256 r = _mm256_set1_ps(0x1.8p23f);
    // z = x*log2(e) + r   (0x1.715476p+0 ~= 1.442695)
    const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
    // n = x*log2(e) rounded to an integer, as a float
    const __m256 n = _mm256_sub_ps(z, r);
    // b = x - n*ln2, with ln2 applied as a high part (0x1.62e4p-1) and a low part (0x1.7f7d1cp-20)
    const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
                                      _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
    // move the integer held in z's low bits into the float exponent field
    const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
    // adding the bit pattern of 1.0f gives k = 2^n
    const __m256 k = _mm256_castsi256_ps(
        _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
    // c: lanes where |n| > 126 (the andnot with -0.f clears the sign bit)
    const __m256i c = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                      _mm256_set1_ps(126), _CMP_GT_OQ));
    const __m256 u = _mm256_mul_ps(b, b);
    // j = ((C5*b + C4)*u + (C3*b + C2))*u + C1*b, with u = b*b
    const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
                                                                     _mm256_set1_ps(0x1.573e2ep-5f)), u,
                                                     _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
                                                                     _mm256_set1_ps(0x1.fffdb6p-2f))),
                                     u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
    // fast path: no lane has |n| > 126, so return j*k + k
    if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
        return _mm256_fmadd_ps(j, k, k);
    // g = 0x82000000 in lanes where n <= 0, otherwise 0
    const __m256i g = _mm256_and_si256(
        _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
        _mm256_set1_epi32(0x82000000u));
    // s1 has bit pattern 0x7f000000 + g (2^127 for n > 0; wraps to 0x01000000 = 2^-125 for n <= 0)
    const __m256 s1 =
        _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
    // s2 has bit pattern e - g
    const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
    // d: lanes where |n| > 192
    const __m256i d = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                      _mm256_set1_ps(192), _CMP_GT_OQ));
    // per lane, selected with and/andnot/or masks:
    //   |n| > 192 -> s1*s1; else |n| > 126 -> (s2*j + s2)*s1; else k*j + k
    return _mm256_or_ps(
        _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
        _mm256_andnot_ps(
            _mm256_castsi256_ps(d),
            _mm256_or_ps(
                _mm256_and_ps(_mm256_castsi256_ps(c),
                              _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
                _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}
  732. // computes silu x/(1+exp(-x)) in single precision vector
  733. inline static __m256 ggml_v_silu(__m256 x) {
  734. const __m256 one = _mm256_set1_ps(1);
  735. const __m256 zero = _mm256_setzero_ps();
  736. const __m256 neg_x = _mm256_sub_ps(zero, x);
  737. const __m256 exp_neg_x = ggml_v_expf(neg_x);
  738. const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
  739. return _mm256_div_ps(x, one_plus_exp_neg_x);
  740. }
  741. #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
// Multiply-add helpers for the 128-bit path:
//   MADD128(x, y, z)  = x*y + z
//   NMADD128(x, y, z) = z - x*y
// With __FMA__ they map to the fused intrinsics; otherwise to a separate multiply
// followed by an add/subtract.
#if defined(__FMA__)
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
#else
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
#endif
// Vectorized expf for four floats (SSE2, using MADD128/NMADD128 for the multiply-adds).
// Adapted from Arm Limited's optimized routines.
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
//
// Outline: n = round(x*log2(e)), b = x - n*ln2, result = 2^n * (1 + j) where j is a
// polynomial in b. Lanes with |n| > 126 go through a slower path that splits the scaling
// into two multiplications.
inline static __m128 ggml_v_expf(__m128 x) {
    // r = 1.5 * 2^23: adding it leaves the rounded integer in the low mantissa bits of z
    const __m128 r = _mm_set1_ps(0x1.8p23f);
    // z = x*log2(e) + r   (0x1.715476p+0 ~= 1.442695)
    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
    // n = x*log2(e) rounded to an integer, as a float
    const __m128 n = _mm_sub_ps(z, r);
    // b = x - n*ln2, with ln2 applied as a high part (0x1.62e4p-1) and a low part (0x1.7f7d1cp-20)
    const __m128 b =
        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
    // move the integer held in z's low bits into the float exponent field
    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
    // adding the bit pattern of 1.0f gives k = 2^n
    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
    // c: lanes where |n| > 126 (the andnot with -0.f clears the sign bit)
    const __m128i c =
        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
    const __m128 u = _mm_mul_ps(b, b);
    // j = ((C5*b + C4)*u + (C3*b + C2))*u + C1*b, with u = b*b
    const __m128 j =
        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
    // fast path: no lane has |n| > 126, so return j*k + k
    if (!_mm_movemask_epi8(c))
        return MADD128(j, k, k);
    // g = 0x82000000 in lanes where n <= 0, otherwise 0
    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
                                    _mm_set1_epi32(0x82000000u));
    // s1 has bit pattern 0x7f000000 + g (2^127 for n > 0; wraps to 0x01000000 = 2^-125 for n <= 0)
    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
    // s2 has bit pattern e - g
    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
    // d: lanes where |n| > 192
    const __m128i d =
        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
    // per lane, selected with and/andnot/or masks:
    //   |n| > 192 -> s1*s1; else |n| > 126 -> (s2*j + s2)*s1; else k*j + k
    return _mm_or_ps(
        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
        _mm_andnot_ps(_mm_castsi128_ps(d),
                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
}
  782. // computes silu x/(1+exp(-x)) in single precision vector
  783. inline static __m128 ggml_v_silu(__m128 x) {
  784. const __m128 one = _mm_set1_ps(1);
  785. const __m128 zero = _mm_setzero_ps();
  786. const __m128 neg_x = _mm_sub_ps(zero, x);
  787. const __m128 exp_neg_x = ggml_v_expf(neg_x);
  788. const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
  789. return _mm_div_ps(x, one_plus_exp_neg_x);
  790. }
  791. #endif // __ARM_NEON / __AVX2__ / __SSE2__
  792. inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  793. for (int i = 0; i < n; ++i) {
  794. y[i] = ggml_silu_f16(x[i]);
  795. }
  796. }
  797. inline static float ggml_silu_backward_f32(float x, float dy) {
  798. const float s = 1.0f/(1.0f + expf(-x));
  799. return dy*s*(1.0f + x*(1.0f - s));
  800. }
  801. inline static ggml_fp16_t ggml_silu_backward_f16(ggml_fp16_t x, ggml_fp16_t dy) {
  802. const float v = GGML_CPU_FP16_TO_FP32(x);
  803. const float s = 1.0f/(1.0f + expf(-v));
  804. return GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(dy)*s*(1.0f + v*(1.0f - s)));
  805. }
  806. inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
  807. for (int i = 0; i < n; ++i) {
  808. dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
  809. }
  810. }
  811. inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) {
  812. for (int i = 0; i < n; ++i) {
  813. dx[i] = ggml_silu_backward_f16(x[i], dy[i]);
  814. }
  815. }
  816. inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
  817. for (int i = 0; i < n; ++i) {
  818. y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
  819. }
  820. }
  821. inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
  822. for (int i = 0; i < n; ++i) {
  823. float v = GGML_CPU_FP16_TO_FP32(x[i]);
  824. y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f);
  825. }
  826. }
  827. #ifdef GGML_GELU_FP16
  828. inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
  829. uint16_t t;
  830. for (int i = 0; i < n; ++i) {
  831. if (x[i] <= -10.0f) {
  832. y[i] = 0.0f;
  833. } else if (x[i] >= 10.0f) {
  834. y[i] = x[i] * g[i];
  835. } else {
  836. ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
  837. memcpy(&t, &fp16, sizeof(uint16_t));
  838. y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
  839. }
  840. }
  841. }
  842. #else
  843. inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
  844. for (int i = 0; i < n; ++i) {
  845. y[i] = ggml_gelu_f32(x[i]) * g[i];
  846. }
  847. }
  848. #endif
  849. inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
  850. const uint16_t * i16 = (const uint16_t *) x;
  851. for (int i = 0; i < n; ++i) {
  852. float v = GGML_CPU_FP16_TO_FP32(g[i]);
  853. y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
  854. }
  855. }
// f32 SwiGLU over n elements; only declared here, with no inline definition in this part of the file.
// NOTE(review): presumably y[i] = silu(x[i]) * g[i], mirroring ggml_vec_swiglu_f16 - confirm against the definition.
void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
  857. inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
  858. for (int i = 0; i < n; ++i) {
  859. float v = GGML_CPU_FP16_TO_FP32(x[i]);
  860. float w = GGML_CPU_FP16_TO_FP32(g[i]);
  861. y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
  862. }
  863. }
  864. inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
  865. for (int i = 0; i < n; ++i) {
  866. float xi = x[i];
  867. y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
  868. }
  869. }
  870. inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
  871. for (int i = 0; i < n; ++i) {
  872. float xi = GGML_CPU_FP16_TO_FP32(x[i]);
  873. float gi = GGML_CPU_FP16_TO_FP32(g[i]);
  874. y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
  875. }
  876. }
  877. #ifdef GGML_GELU_QUICK_FP16
  878. inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
  879. uint16_t t;
  880. for (int i = 0; i < n; ++i) {
  881. ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
  882. memcpy(&t, &fp16, sizeof(uint16_t));
  883. y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
  884. }
  885. }
  886. #else
  887. inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
  888. for (int i = 0; i < n; ++i) {
  889. y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
  890. }
  891. }
  892. #endif
  893. inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
  894. const uint16_t * i16 = (const uint16_t *) x;
  895. for (int i = 0; i < n; ++i) {
  896. float v = GGML_CPU_FP16_TO_FP32(g[i]);
  897. y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
  898. }
  899. }
  900. inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
  901. #ifndef GGML_USE_ACCELERATE
  902. ggml_float sum = 0.0;
  903. for (int i = 0; i < n; ++i) {
  904. sum += (ggml_float)x[i];
  905. }
  906. *s = (float)sum;
  907. #else
  908. vDSP_sve(x, 1, s, n);
  909. #endif
  910. }
  911. inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
  912. ggml_float sum = 0.0;
  913. for (int i = 0; i < n; ++i) {
  914. sum += (ggml_float)x[i];
  915. }
  916. *s = sum;
  917. }
  918. inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
  919. float sum = 0.0f;
  920. for (int i = 0; i < n; ++i) {
  921. sum += GGML_CPU_FP16_TO_FP32(x[i]);
  922. }
  923. *s = sum;
  924. }
  925. inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
  926. float sum = 0.0f;
  927. for (int i = 0; i < n; ++i) {
  928. sum += GGML_BF16_TO_FP32(x[i]);
  929. }
  930. *s = sum;
  931. }
  932. inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
  933. #ifndef GGML_USE_ACCELERATE
  934. float max = -INFINITY;
  935. for (int i = 0; i < n; ++i) {
  936. max = MAX(max, x[i]);
  937. }
  938. *s = max;
  939. #else
  940. vDSP_maxv(x, 1, s, n);
  941. #endif
  942. }
  943. inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
  944. ggml_vec_norm_f32(n, s, x);
  945. *s = 1.f/(*s);
  946. }
  947. inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
  948. float max = -INFINITY;
  949. int idx = 0;
  950. for (int i = 0; i < n; ++i) {
  951. max = MAX(max, x[i]);
  952. if (max == x[i]) { idx = i; }
  953. }
  954. *s = idx;
  955. }
  956. #ifdef __cplusplus
  957. }
  958. #endif