#include "vec.h"

#include <cassert>

// precomputed gelu table for f16 (128 KB)
ggml_fp16_t ggml_table_gelu_f16[1 << 16];

// precomputed quick gelu table for f16 (128 KB)
ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
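
// Note (editorial): each table has 1 << 16 entries of ggml_fp16_t (2 bytes), i.e. 128 KB, one
// entry per possible f16 bit pattern, so an f16 GELU presumably reduces to a single table
// lookup keyed by the raw bits. Illustrative sketch only, not part of the original file:
//
//   static inline ggml_fp16_t gelu_f16_lookup(ggml_fp16_t x) {
//       uint16_t bits;
//       memcpy(&bits, &x, sizeof(bits)); // reinterpret the f16 value as its 16-bit pattern
//       return ggml_table_gelu_f16[bits];
//   }
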
void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

#if defined(GGML_SIMD)
    float sumf = 0.0f;

    #if defined(__ARM_FEATURE_SVE)
        const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
        const int ggml_f32_epr  = sve_register_length / 32; // f32 elements per SVE register (== svcntw()): SVE128:4, SVE256:8, SVE512:16
        const int ggml_f32_step = 8 * ggml_f32_epr;         // unroll the main loop over 8 SVE registers

        const int np = (n & ~(ggml_f32_step - 1));

        svfloat32_t sum1 = svdup_n_f32(0.0f);
        svfloat32_t sum2 = svdup_n_f32(0.0f);
        svfloat32_t sum3 = svdup_n_f32(0.0f);
        svfloat32_t sum4 = svdup_n_f32(0.0f);
        svfloat32_t sum5 = svdup_n_f32(0.0f);
        svfloat32_t sum6 = svdup_n_f32(0.0f);
        svfloat32_t sum7 = svdup_n_f32(0.0f);
        svfloat32_t sum8 = svdup_n_f32(0.0f);

        svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
        svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;

        for (int i = 0; i < np; i += ggml_f32_step) {
            ax1  = GGML_F32_VEC_LOAD(x + i);
            ay1  = GGML_F32_VEC_LOAD(y + i);
            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);

            ax2  = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
            ay2  = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
            sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);

            ax3  = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
            ay3  = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
            sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);

            ax4  = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
            ay4  = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
            sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);

            ax5  = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
            ay5  = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
            sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);

            ax6  = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
            ay6  = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
            sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);

            ax7  = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
            ay7  = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
            sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);

            ax8  = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
            ay8  = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
            sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
        }

        // leftovers
        // since the main loop is unrolled by 8, the remainder lies in [0, ggml_f32_step) and is handled below
        const int np2 = (n & ~(ggml_f32_epr - 1));
        for (int i = np; i < np2; i += ggml_f32_epr) {
            ax1  = GGML_F32_VEC_LOAD(x + i);
            ay1  = GGML_F32_VEC_LOAD(y + i);
            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
        }

        // at most ggml_f32_epr - 1 elements remain: apply a predicated svmad on the available elements only
        if (np2 < n) {
            svbool_t pg = svwhilelt_b32(np2, n);
            ax1  = svld1_f32(pg, x + np2);
            ay1  = svld1_f32(pg, y + np2);
            sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
        }

        // reduce sum1..sum8 into sumf
        GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
    #else
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

        GGML_F32_VEC ax[GGML_F32_ARR];
        GGML_F32_VEC ay[GGML_F32_ARR];

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);

                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
            }
        }

        // reduce the GGML_F32_ARR partial sums into sumf
        GGML_F32_VEC_REDUCE(sumf, sum);

        // leftovers
        for (int i = np; i < n; ++i) {
            sumf += x[i]*y[i];
        }
    #endif
#else
    // scalar
    ggml_float sumf = 0.0;
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(x[i]*y[i]);
    }
#endif

    *s = sumf;
}
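
// Illustrative usage (editorial sketch, not part of the original file): with nrc == 1 the
// stride arguments bs, bx, by are unused, so a plain contiguous dot product looks like this:
//
//   float a[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
//   float b[4] = { 4.0f, 3.0f, 2.0f, 1.0f };
//   float d    = 0.0f;
//   ggml_vec_dot_f32(4, &d, 0, a, 0, b, 0, 1);
//   // d == 1*4 + 2*3 + 3*2 + 4*1 == 20
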
void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

    int i = 0;
    ggml_float sumf = 0;

#if defined(__AVX512BF16__)
    __m512 c1 = _mm512_setzero_ps();
    __m512 c2 = _mm512_setzero_ps();
    for (; i + 64 <= n; i += 64) {
        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
                                  m512bh(_mm512_loadu_si512((y + i))));
        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
                                  m512bh(_mm512_loadu_si512((y + i + 32))));
    }
    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
    sumf += (ggml_float)_mm512_reduce_add_ps(c2);

#elif defined(__AVX512F__)
#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
    __m512 c1 = _mm512_setzero_ps();
    __m512 c2 = _mm512_setzero_ps();
    for (; i + 32 <= n; i += 32) {
        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
    }
    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
#undef LOAD

#elif defined(__AVX2__) || defined(__AVX__)
#if defined(__AVX2__)
#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
#else
#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
#endif
    __m256 c1 = _mm256_setzero_ps();
    __m256 c2 = _mm256_setzero_ps();
    __m256 c3 = _mm256_setzero_ps();
    __m256 c4 = _mm256_setzero_ps();
    for (; i + 32 <= n; i += 32) {
        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
    }
    __m128 g;
    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
                       _mm256_add_ps(c2, c4));
    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
                   _mm256_castps256_ps128(c1));
    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
    g = _mm_add_ss(g, _mm_movehdup_ps(g));
    sumf += (ggml_float)_mm_cvtss_f32(g);
#undef LOAD
#endif

    // leftovers (and the scalar fallback when no SIMD path is compiled in)
    for (; i < n; ++i) {
        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
                             GGML_BF16_TO_FP32(y[i]));
    }
    *s = sumf;
}
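
// Note (editorial): the LOAD macros above convert bf16 to f32 by zero-extending each 16-bit
// value to 32 bits and shifting it left by 16, since bf16 is simply the upper half of an
// IEEE-754 binary32. Scalar equivalent, for reference only (not part of the original file):
//
//   static inline float bf16_to_f32_ref(uint16_t h) {
//       uint32_t bits = (uint32_t) h << 16; // bf16 bits become the high half of an f32
//       float f;
//       memcpy(&f, &bits, sizeof(f));
//       return f;
//   }
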
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

    ggml_float sumf = 0.0;

#if defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };

    GGML_F16_VEC ax[GGML_F16_ARR];
    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);

            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
        }
    }

    // reduce the GGML_F16_ARR partial sums into sumf
    GGML_F16_VEC_REDUCE(sumf, sum);

    // leftovers
    for (int i = np; i < n; ++i) {
        sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
    }

    // if you hit this, you are likely running outside the FP range
    assert(!isnan(sumf) && !isinf(sumf));
#else
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
    }
#endif

    *s = sumf;
}
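
// Note (editorial): the expression np = n & ~(GGML_F16_STEP - 1), used throughout this file,
// rounds n down to a multiple of the (power-of-two) step so the scalar tail handles the rest.
// Worked example, assuming a step of 32 for illustration:
//
//   // n = 100, step = 32:
//   //   np = 100 & ~31 = 96  ->  96 elements go through the SIMD loop, 4 through the tail
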
void ggml_vec_silu_f32(const int n, float * y, const float * x) {
    int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
    }
#endif
    for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]);
    }
}
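
// Note (editorial): SiLU (a.k.a. swish) is silu(x) = x * sigmoid(x) = x / (1 + exp(-x));
// ggml_silu_f32 and ggml_v_silu compute this per element. Scalar reference, for illustration
// only (not part of the original file):
//
//   static inline float silu_ref(float x) {
//       return x / (1.0f + expf(-x));
//   }
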
void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
    int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
    }
#endif
    for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]) * g[i];
    }
}
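
// Note (editorial): this is the fused SwiGLU-style activation y[i] = silu(x[i]) * g[i],
// where one input gates the other. Scalar reference, for illustration only:
//
//   for (int i = 0; i < n; ++i) {
//       y[i] = (x[i] / (1.0f + expf(-x[i]))) * g[i];
//   }
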
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
    int i = 0;
    ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
                                               _mm512_set1_ps(max)));
        _mm512_storeu_ps(y + i, val);
        sum += (ggml_float)_mm512_reduce_add_ps(val);
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
                                               _mm256_set1_ps(max)));
        _mm256_storeu_ps(y + i, val);
        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
                                 _mm256_castps256_ps128(val));
        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
        sum += (ggml_float)_mm_cvtss_f32(val2);
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
                                            _mm_set1_ps(max)));
        _mm_storeu_ps(y + i, val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
        val = _mm_add_ps(val, tmp);
        tmp = _mm_movehl_ps(tmp, val);
        val = _mm_add_ss(val, tmp);
#endif
        sum += (ggml_float)_mm_cvtss_f32(val);
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
                                                vdupq_n_f32(max)));
        vst1q_f32(y + i, val);
        sum += (ggml_float)vaddvq_f32(val);
    }
#endif
    for (; i < n; ++i) {
        float val = expf(x[i] - max);
        sum += (ggml_float)val;
        y[i] = val;
    }
    return sum;
}
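
// Note (editorial): this writes the unnormalized values y[i] = exp(x[i] - max) and returns
// their sum; the caller then divides by that sum to finish the softmax. Subtracting the row
// maximum first is the usual overflow-avoidance trick and does not change the result, since
// exp(x_i - max) / sum_j exp(x_j - max) == exp(x_i) / sum_j exp(x_j). Sketch of the
// normalization step (illustrative, not from this file):
//
//   ggml_float s   = ggml_vec_soft_max_f32(n, y, x, max);
//   const float inv = 1.0f / (float) s;
//   for (int i = 0; i < n; ++i) {
//       y[i] *= inv;
//   }
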
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
    // log(soft_max_i) = log(exp(logit_i - max) / soft_max_sum) = (logit_i - max) - log(soft_max_sum)
    int i = 0;
    ggml_float sum = 0;
    for (; i < n; ++i) {
        float val = x[i] - max;
        y[i] = val;
        sum += (ggml_float)expf(val);
    }
    return (ggml_float)logf(sum);
}
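
// Note (editorial): the stored y[i] = x[i] - max are only the shifted logits; the full
// log-softmax is obtained by subtracting the returned log-sum. Illustrative sketch:
//
//   ggml_float lse = ggml_vec_log_soft_max_f32(n, y, x, max);
//   for (int i = 0; i < n; ++i) {
//       y[i] -= (float) lse; // now y[i] == log(softmax(x)[i])
//   }
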