#include "vec.h"

#include <cassert>
#include <cmath> // expf/logf used by the scalar fallbacks below

#if defined(_MSC_VER)
// disable "possible loss of data" to avoid hundreds of casts
// we should just be careful :)
#pragma warning(disable: 4244 4267)
#endif

// precomputed gelu table for f16 (128 KB)
ggml_fp16_t ggml_table_gelu_f16[1 << 16];

// precomputed quick gelu table for f16 (128 KB)
ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

#if defined(GGML_SIMD)
    float sumf = 0.0f;
    const int np = (n & ~(GGML_F32_STEP - 1));

    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

    GGML_F32_VEC ax[GGML_F32_ARR];
    GGML_F32_VEC ay[GGML_F32_ARR];
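
    // main loop: GGML_F32_STEP floats per iteration, accumulated into GGML_F32_ARR independent partial sums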
    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);

            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
        }
    }

    // reduce sum0..sum3 to sum0
    GGML_F32_VEC_REDUCE(sumf, sum);

    // leftovers
    for (int i = np; i < n; ++i) {
        sumf += x[i]*y[i];
    }
#else
    // scalar
    ggml_float sumf = 0.0;
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(x[i]*y[i]);
    }
#endif

    *s = sumf;
}
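
// bf16 dot product: uses the AVX512-BF16 dot-product instruction (vdpbf16ps) when available,
// otherwise widens bf16 to f32 and falls back to multiply-add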
void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

    int i = 0;
    ggml_float sumf = 0;

#if defined(__AVX512BF16__)
    __m512 c1 = _mm512_setzero_ps();
    __m512 c2 = _mm512_setzero_ps();
    for (; i + 64 <= n; i += 64) {
        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
                                  m512bh(_mm512_loadu_si512((y + i))));
        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
                                  m512bh(_mm512_loadu_si512((y + i + 32))));
    }
    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
#elif defined(__AVX512F__)
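// bf16 is the upper 16 bits of an f32 bit pattern: widen each lane to 32 bits and
// shift left by 16 to convert bf16 -> f32 exactly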
#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
    __m512 c1 = _mm512_setzero_ps();
    __m512 c2 = _mm512_setzero_ps();
    for (; i + 32 <= n; i += 32) {
        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
    }
    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
#undef LOAD
#elif defined(__AVX2__) || defined(__AVX__)
#if defined(__AVX2__)
#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
#else
#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
#endif
    __m256 c1 = _mm256_setzero_ps();
    __m256 c2 = _mm256_setzero_ps();
    __m256 c3 = _mm256_setzero_ps();
    __m256 c4 = _mm256_setzero_ps();
    for (; i + 32 <= n; i += 32) {
        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
    }
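    // combine the four accumulators, then horizontally add the 8 lanes down to one float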
    __m128 g;
    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
                       _mm256_add_ps(c2, c4));
    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
                   _mm256_castps256_ps128(c1));
    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
    g = _mm_add_ss(g, _mm_movehdup_ps(g));
    sumf += (ggml_float)_mm_cvtss_f32(g);
#undef LOAD
#endif

    for (; i < n; ++i) {
        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
                             GGML_BF16_TO_FP32(y[i]));
    }

    *s = sumf;
}
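
// same structure as ggml_vec_dot_f32, but loads and accumulates through the GGML_F16_* SIMD macros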
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

    ggml_float sumf = 0.0;

#if defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };

    GGML_F16_VEC ax[GGML_F16_ARR];
    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);

            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
        }
    }

    // reduce sum0..sum3 to sum0
    GGML_F16_VEC_REDUCE(sumf, sum);

    // leftovers
    for (int i = np; i < n; ++i) {
        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
    }
#else
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
    }
#endif

    *s = sumf;
}
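
// y[i] = silu(x[i]) = x[i] / (1 + exp(-x[i])); vectorized via ggml_v_silu where available,
// with a scalar tail loop for the remaining elements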
void ggml_vec_silu_f32(const int n, float * y, const float * x) {
    int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
    }
#endif
    for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]);
    }
}
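
// writes exp(x[i] - max) into y and returns the sum of the exponentials;
// subtracting max before exponentiating keeps expf in range (standard softmax stabilization)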
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
    int i = 0;
    ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
                                               _mm512_set1_ps(max)));
        _mm512_storeu_ps(y + i, val);
        sum += (ggml_float)_mm512_reduce_add_ps(val);
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
                                               _mm256_set1_ps(max)));
        _mm256_storeu_ps(y + i, val);
        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
                                 _mm256_castps256_ps128(val));
        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
        sum += (ggml_float)_mm_cvtss_f32(val2);
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
                                            _mm_set1_ps(max)));
        _mm_storeu_ps(y + i, val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
        val = _mm_add_ps(val, tmp);
        tmp = _mm_movehl_ps(tmp, val);
        val = _mm_add_ss(val, tmp);
#endif
        sum += (ggml_float)_mm_cvtss_f32(val);
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
                                                vdupq_n_f32(max)));
        vst1q_f32(y + i, val);
        sum += (ggml_float)vaddvq_f32(val);
    }
#endif
    for (; i < n; ++i) {
        float val = expf(x[i] - max);
        sum += (ggml_float)val;
        y[i] = val;
    }
    return sum;
}
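
// writes y[i] = x[i] - max and returns log(sum_j exp(x[j] - max)),
// so the log-softmax of element i is y[i] minus the returned value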
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_sum)
    int i = 0;
    ggml_float sum = 0;
    for (; i < n; ++i) {
        float val = x[i] - max;
        y[i] = val;
        sum += (ggml_float)expf(val);
    }
    return (ggml_float)logf(sum);
}