| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106 |
- // Vectorized functions for fundamental operations
- #pragma once
- #include "ggml-impl.h"
- #include "simd-mappings.h"
- #include "ggml.h"
- #include "ggml-cpu.h"
- #if defined(GGML_USE_ACCELERATE)
- #include <Accelerate/Accelerate.h>
- #endif
- // floating point type used to accumulate sums (double precision)
- typedef double ggml_float;
- #define GGML_GELU_FP16 // when defined, ggml_vec_gelu_f32 looks results up in ggml_table_gelu_f16
- #define GGML_GELU_QUICK_FP16 // when defined, ggml_vec_gelu_quick_f32 looks results up in ggml_table_gelu_quick_f16
- #define GGML_SOFT_MAX_UNROLL 4 // not referenced in this part of the file; presumably a soft max unroll factor -- confirm at the use site
- #define GGML_VEC_DOT_UNROLL 2 // number of dot products ggml_vec_dot_f16_unroll computes per call
- #define GGML_VEC_MAD_UNROLL 32 // number of (row, scalar) pairs ggml_vec_mad_f32_unroll accumulates per call
- #ifdef __cplusplus
- extern "C" {
- #endif
- //
- // global data (declared here, defined elsewhere)
- //
- // precomputed gelu table for f16 (128 KB): 1 << 16 entries, indexed by the raw 16-bit pattern of the f16 input
- extern ggml_fp16_t ggml_table_gelu_f16[1 << 16];
- // precomputed quick gelu table for f16 (128 KB): 1 << 16 entries, indexed by the raw 16-bit pattern of the f16 input
- extern ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
- //
- // fundamental operations
- // (the prototypes directly below are only declared in this header; the inline helpers after them are defined here)
- void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc); // ggml_vec_norm_f32 calls this with bs = bx = by = 0, nrc = 1 and reads the result from *s
- void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc); // bf16 counterpart; NOTE(review): x and y are not const-qualified here, unlike the f32 prototype
- void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); // f16 counterpart; NOTE(review): x and y are not const-qualified here, unlike the f32 prototype
- void ggml_vec_silu_f32(const int n, float * y, const float * x); // presumably y[i] = silu(x[i]) for n elements -- confirm against the definition
- ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); // meaning of the return value is not visible here -- confirm against the definition
- ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); // meaning of the return value is not visible here -- confirm against the definition
// Store the value v into each of the first n elements of x.
inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) {
    int idx = 0;
    while (idx < n) {
        x[idx] = v;
        ++idx;
    }
}
// Store the value v into each of the first n elements of x.
inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) {
    int idx = 0;
    while (idx < n) {
        x[idx] = v;
        ++idx;
    }
}
// Store the value v into each of the first n elements of x.
inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
    int idx = 0;
    while (idx < n) {
        x[idx] = v;
        ++idx;
    }
}
// Copy the first n elements of x into y, in ascending index order.
inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) {
    int idx = 0;
    while (idx < n) {
        y[idx] = x[idx];
        ++idx;
    }
}
- inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
- inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
// Element-wise sum: z[i] = x[i] + y[i] for the first n elements.
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
    int idx = 0;
    while (idx < n) {
        z[idx] = x[idx] + y[idx];
        ++idx;
    }
}
- inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
- for (int i = 0; i < n; ++i) {
- z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
- }
- }
// Add the scalar v to every element: z[i] = x[i] + v for the first n elements.
inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) {
    int idx = 0;
    while (idx < n) {
        z[idx] = x[idx] + v;
        ++idx;
    }
}
// Accumulate x into y in place: y[i] += x[i] for the first n elements.
inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) {
    int idx = 0;
    while (idx < n) {
        y[idx] += x[idx];
        ++idx;
    }
}
// Add the scalar v to y in place: y[i] += v for the first n elements.
inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) {
    int idx = 0;
    while (idx < n) {
        y[idx] += v;
        ++idx;
    }
}
// Element-wise difference: z[i] = x[i] - y[i] for the first n elements.
inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) {
    int idx = 0;
    while (idx < n) {
        z[idx] = x[idx] - y[idx];
        ++idx;
    }
}
- inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
- for (int i = 0; i < n; ++i) {
- z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) - GGML_CPU_FP16_TO_FP32(y[i]));
- }
- }
// Store the value v into each of the first n elements of x.
inline static void ggml_vec_set_f32 (const int n, float * x, const float v) {
    int idx = 0;
    while (idx < n) {
        x[idx] = v;
        ++idx;
    }
}
// Copy the first n elements of x into y, in ascending index order.
inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) {
    int idx = 0;
    while (idx < n) {
        y[idx] = x[idx];
        ++idx;
    }
}
// Element-wise negation: y[i] = -x[i] for the first n elements.
inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) {
    int idx = 0;
    while (idx < n) {
        y[idx] = -x[idx];
        ++idx;
    }
}
- inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(-GGML_CPU_FP16_TO_FP32(x[i]));
- }
- }
// Element-wise product: z[i] = x[i]*y[i] for the first n elements.
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) {
    int idx = 0;
    while (idx < n) {
        z[idx] = x[idx]*y[idx];
        ++idx;
    }
}
- inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
- for (int i = 0; i < n; ++i) {
- z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) * GGML_CPU_FP16_TO_FP32(y[i]));
- }
- }
// Element-wise quotient: z[i] = x[i]/y[i] for the first n elements (no check for zero divisors).
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) {
    int idx = 0;
    while (idx < n) {
        z[idx] = x[idx]/y[idx];
        ++idx;
    }
}
- inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
- for (int i = 0; i < n; ++i) {
- z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) / GGML_CPU_FP16_TO_FP32(y[i]));
- }
- }
- // compute GGML_VEC_DOT_UNROLL dot products at once
- // xs - x row stride in bytes
- inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GGML_RESTRICT s, void * GGML_RESTRICT xv, ggml_fp16_t * GGML_RESTRICT y) {
- ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
- ggml_fp16_t * GGML_RESTRICT x[GGML_VEC_DOT_UNROLL];
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
- x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
- }
- #if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
- GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
- GGML_F16_VEC ax[GGML_F16_ARR];
- GGML_F16_VEC ay[GGML_F16_ARR];
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
- for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
- ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
- sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
- }
- }
- }
- // reduce sum0..sum3 to sum0
- for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
- GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
- }
- // leftovers
- for (int i = np; i < n; ++i) {
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
- sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
- }
- }
- #else
- for (int i = 0; i < n; ++i) {
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
- sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
- }
- }
- #endif
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
- s[i] = (float)sumf[i];
- }
- }
- inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) {
- #if defined(GGML_SIMD)
- #if defined(__ARM_FEATURE_SVE)
- const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
- const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16
- const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers
- GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
- const int np = (n & ~(ggml_f32_step - 1));
- svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
- svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
- for (int i = 0; i < np; i += ggml_f32_step) {
- ax1 = GGML_F32_VEC_LOAD(x + i);
- ay1 = GGML_F32_VEC_LOAD(y + i);
- ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
- GGML_F32_VEC_STORE(y + i, ay1);
- ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
- ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
- ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
- GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
- ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
- ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
- ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
- GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
- ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
- ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
- ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
- GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
- ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
- ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
- ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
- GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
- ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
- ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
- ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
- GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
- ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
- ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
- ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
- GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
- ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
- ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
- ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
- GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
- }
- // leftovers
- // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
- const int np2 = (n & ~(ggml_f32_epr - 1));
- for (int i = np; i < np2; i += ggml_f32_epr) {
- ax1 = GGML_F32_VEC_LOAD(x + i);
- ay1 = GGML_F32_VEC_LOAD(y + i);
- ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
- GGML_F32_VEC_STORE(y + i, ay1);
- }
- // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
- if (np2 < n) {
- svbool_t pg =svwhilelt_b32(np2, n);
- ax1 = svld1_f32(pg, x + np2);
- ay1 = svld1_f32(pg, y + np2);
- ay1 = svmad_f32_m(pg, ax1, vx, ay1);
- svst1_f32(pg, y + np2, ay1);
- }
- #else
- const int np = (n & ~(GGML_F32_STEP - 1));
- GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
- GGML_F32_VEC ax[GGML_F32_ARR];
- GGML_F32_VEC ay[GGML_F32_ARR];
- for (int i = 0; i < np; i += GGML_F32_STEP) {
- for (int j = 0; j < GGML_F32_ARR; j++) {
- ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
- GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
- }
- }
- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] += x[i]*v;
- }
- #endif
- #else
- // scalar
- for (int i = 0; i < n; ++i) {
- y[i] += x[i]*v;
- }
- #endif
- }
- inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
- #if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
- GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
- GGML_F16_VEC ax[GGML_F16_ARR];
- GGML_F16_VEC ay[GGML_F16_ARR];
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
- GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
- }
- }
- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
- }
- #else
- // scalar
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
- }
- #endif
- }
- // xs and vs are byte strides of x and v
- inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * GGML_RESTRICT y, const float * GGML_RESTRICT xv, const float * GGML_RESTRICT vv) {
- const float * GGML_RESTRICT x[GGML_VEC_MAD_UNROLL];
- const float * GGML_RESTRICT v[GGML_VEC_MAD_UNROLL];
- for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
- x[i] = (const float *) ((const char *) xv + i*xs);
- v[i] = (const float *) ((const char *) vv + i*vs);
- }
- #if defined(GGML_SIMD)
- #if defined(__ARM_FEATURE_SVE)
- // scalar Route to scalar implementation //TODO: Write SVE code
- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- for (int i = 0; i < n; ++i) {
- y[i] += x[k][i]*v[k][0];
- }
- }
- #else
- const int np = (n & ~(GGML_F32_STEP - 1));
- GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- vx[k] = GGML_F32_VEC_SET1(v[k][0]);
- }
- GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
- GGML_F32_VEC ay[GGML_F32_ARR];
- for (int i = 0; i < np; i += GGML_F32_STEP) {
- for (int j = 0; j < GGML_F32_ARR; j++) {
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
- }
- GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
- }
- }
- // leftovers
- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- for (int i = np; i < n; ++i) {
- y[i] += x[k][i]*v[k][0];
- }
- }
- #endif
- #else
- // scalar
- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- for (int i = 0; i < n; ++i) {
- y[i] += x[k][i]*v[k][0];
- }
- }
- #endif
- }
// Scale and shift: y[i] = x[i]*s + b for the first n elements.
// Uses vDSP_vsmsa with Accelerate, the generic GGML_SIMD vectors when available,
// and a scalar loop otherwise.
inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
#if defined(GGML_USE_ACCELERATE)
    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
#else
    // first element index that the scalar loop below still has to process
    int scalar_begin = 0;

#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE)
    const int np = (n & ~(GGML_F32_STEP - 1));

    GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
    GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);

    GGML_F32_VEC ay[GGML_F32_ARR];

    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);

            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
        }
    }

    scalar_begin = np;
#endif

    // scalar: leftovers of the SIMD loop, or the whole range when GGML_SIMD is not defined.
    // SVE builds are also routed here for the whole range. TODO: Write SVE code
    for (int i = scalar_begin; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#endif
}
// Scale in place: y[i] *= v for the first n elements.
// Dispatches to vDSP_vsmul (Accelerate), SVE, the generic GGML_SIMD vectors, or a scalar loop,
// depending on the build configuration.
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(GGML_USE_ACCELERATE)
    vDSP_vsmul(y, 1, &v, y, 1, n);
#elif defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
        const int ggml_f32_epr = sve_register_length / 32; // SVE128:4, SVE256:8, SVE512:16
        const int ggml_f32_step = 2 * ggml_f32_epr;        // two SVE registers per iteration

        GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
        const int np = (n & ~(ggml_f32_step - 1));
        svfloat32_t ay1;
        svfloat32_t ay2;
        for (int i = 0; i < np; i += ggml_f32_step) {
            ay1 = GGML_F32_VEC_LOAD(y + i);
            ay1 = GGML_F32_VEC_MUL(ay1, vx);
            GGML_F32_VEC_STORE(y + i, ay1);

            ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
            ay2 = GGML_F32_VEC_MUL(ay2, vx);
            GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
        }
        // leftovers
        // up to ggml_f32_step - 1 (= 2*ggml_f32_epr - 1) elements can remain, which is more than
        // one predicated vector covers, so walk the tail one predicated vector at a time
        for (int i = np; i < n; i += ggml_f32_epr) {
            svbool_t pg = svwhilelt_b32(i, n);
            ay1 = svld1_f32(pg, y + i);
            ay1 = svmul_f32_m(pg, ay1, vx);
            svst1_f32(pg, y + i, ay1);
        }
    #else
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

        GGML_F32_VEC ay[GGML_F32_ARR];

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
                ay[j] = GGML_F32_VEC_MUL(ay[j], vx);

                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
            }
        }

        // leftovers
        for (int i = np; i < n; ++i) {
            y[i] *= v;
        }
    #endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] *= v;
    }
#endif
}
- inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
- #if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
- GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
- GGML_F16_VEC ay[GGML_F16_ARR];
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
- GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
- }
- }
- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
- }
- #else
- // scalar
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
- }
- #endif
- }
// Writes to *s the square root of the value ggml_vec_dot_f32 produces for x with itself
// (the Euclidean norm of the first n elements of x, given that helper is a dot product).
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) {
    ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1);
    const float dot_xx = *s;
    *s = sqrtf(dot_xx);
}
// Element-wise square: y[i] = x[i]*x[i] for the first n elements.
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) {
    int idx = 0;
    while (idx < n) {
        y[idx] = x[idx]*x[idx];
        ++idx;
    }
}
- inline static void ggml_vec_sqr_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
- y[i] = GGML_CPU_FP32_TO_FP16(v*v);
- }
- }
- inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
- inline static void ggml_vec_sqrt_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(sqrtf(GGML_CPU_FP16_TO_FP32(x[i])));
- }
- }
- inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
- inline static void ggml_vec_log_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(logf(GGML_CPU_FP16_TO_FP32(x[i])));
- }
- }
- inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); }
- inline static void ggml_vec_sin_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(sinf(GGML_CPU_FP16_TO_FP32(x[i])));
- }
- }
- inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); }
- inline static void ggml_vec_cos_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(cosf(GGML_CPU_FP16_TO_FP32(x[i])));
- }
- }
// Element-wise absolute value: y[i] = fabsf(x[i]) for the first n elements.
inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) {
    int idx = 0;
    while (idx < n) {
        y[idx] = fabsf(x[idx]);
        ++idx;
    }
}
- inline static void ggml_vec_abs_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(fabsf(GGML_CPU_FP16_TO_FP32(x[i])));
- }
- }
// Element-wise sign: y[i] is 1 for positive x[i], -1 for negative x[i], and 0 otherwise
// (which includes zero and NaN, since both comparisons are false for them).
inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) {
    for (int idx = 0; idx < n; ++idx) {
        const float val = x[idx];
        float sign = 0.f;
        if (val > 0.f) {
            sign = 1.f;
        } else if (val < 0.f) {
            sign = -1.f;
        }
        y[idx] = sign;
    }
}
- inline static void ggml_vec_sgn_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
- y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? 1.f : ((v < 0.f) ? -1.f : 0.f));
- }
- }
// Element-wise step function: y[i] is 1 when x[i] > 0 and 0 otherwise.
inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
    for (int idx = 0; idx < n; ++idx) {
        if (x[idx] > 0.f) {
            y[idx] = 1.f;
        } else {
            y[idx] = 0.f;
        }
    }
}
- inline static void ggml_vec_step_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16((GGML_CPU_FP16_TO_FP32(x[i]) > 0.f) ? 1.f : 0.f);
- }
- }
- inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
- inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(tanhf(GGML_CPU_FP16_TO_FP32(x[i])));
- }
- }
- inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
- inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
- }
- }
// Element-wise ReLU: y[i] = x[i] when x[i] > 0, otherwise 0.
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) {
    for (int idx = 0; idx < n; ++idx) {
        if (x[idx] > 0.f) {
            y[idx] = x[idx];
        } else {
            y[idx] = 0.f;
        }
    }
}
- inline static void ggml_vec_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
- y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : 0.f);
- }
- }
// Element-wise leaky ReLU with negative slope ns:
// y[i] = max(x[i], 0) + ns * min(x[i], 0), with both parts selected by comparison.
inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) {
    for (int idx = 0; idx < n; ++idx) {
        const float pos_part = (x[idx] > 0.f) ? x[idx] : 0.f;
        const float neg_part = (x[idx] < 0.0f) ? x[idx] : 0.f;
        y[idx] = pos_part + ns * neg_part;
    }
}
- inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const float ns) {
- for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
- y[i] = GGML_CPU_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? v : 0.f));
- }
- }
- inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
- inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(1.f / (1.f + expf(-GGML_CPU_FP16_TO_FP32(x[i]))));
- }
- }
- // TODO: optimize performance
- inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
- inline static void ggml_vec_hardswish_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
- y[i] = GGML_CPU_FP32_TO_FP16(v * fminf(1.0f, fmaxf(0.0f, (v + 3.0f) / 6.0f)));
- }
- }
- inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
- inline static void ggml_vec_hardsigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(fminf(1.0f, fmaxf(0.0f, (GGML_CPU_FP16_TO_FP32(x[i]) + 3.0f) / 6.0f)));
- }
- }
- inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
- inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(expf(GGML_CPU_FP16_TO_FP32(x[i])));
- }
- }
- static const float GELU_COEF_A = 0.044715f; // coefficient of the x*x term inside the tanh argument of ggml_gelu_f32
- static const float GELU_QUICK_COEF = -1.702f; // multiplies x inside expf() in the quick GELU: x/(1 + expf(GELU_QUICK_COEF*x))
- static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; // sqrt(2/pi); scales the tanh argument in ggml_gelu_f32
- static const float SQRT_2_INV = 0.70710678118654752440084436210484f; // 1/sqrt(2); scales the erff argument in the gelu_erf variants
- inline static float ggml_gelu_f32(float x) {
- return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
- }
- inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- const uint16_t * i16 = (const uint16_t *) x;
- for (int i = 0; i < n; ++i) {
- y[i] = ggml_table_gelu_f16[i16[i]];
- }
- }
- inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- float xi = GGML_CPU_FP16_TO_FP32(x[i]);
- float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
- y[i] = GGML_CPU_FP32_TO_FP16(res);
- }
- }
// Element-wise GELU on f32 data.
// With GGML_GELU_FP16 defined: inputs <= -10 give 0, inputs >= 10 are passed through, and
// everything else is rounded to f16 and looked up in ggml_table_gelu_f16.
// Without it: every element is evaluated with ggml_gelu_f32.
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
    for (int idx = 0; idx < n; ++idx) {
#ifdef GGML_GELU_FP16
        const float val = x[idx];
        if (val <= -10.0f) {
            y[idx] = 0.0f;
            continue;
        }
        if (val >= 10.0f) {
            y[idx] = val;
            continue;
        }
        // use the f16 bit pattern of the rounded input as the table index
        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(val);
        uint16_t slot;
        memcpy(&slot, &fp16, sizeof(uint16_t));
        y[idx] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[slot]);
#else
        y[idx] = ggml_gelu_f32(x[idx]);
#endif
    }
}
- inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
- for (int i = 0; i < n; ++i) {
- float xi = x[i];
- y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
- }
- }
- inline static float ggml_gelu_quick_f32(float x) {
- return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
- }
- //inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- // const uint16_t * i16 = (const uint16_t *) x;
- // for (int i = 0; i < n; ++i) {
- // y[i] = ggml_table_gelu_quick_f16[i16[i]];
- // }
- //}
// Element-wise quick GELU on f32 data.
// With GGML_GELU_QUICK_FP16 defined: each input is rounded to f16 and looked up in
// ggml_table_gelu_quick_f16. Without it: every element is evaluated with ggml_gelu_quick_f32.
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
    for (int idx = 0; idx < n; ++idx) {
#ifdef GGML_GELU_QUICK_FP16
        // use the f16 bit pattern of the rounded input as the table index
        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[idx]);
        uint16_t slot;
        memcpy(&slot, &fp16, sizeof(uint16_t));
        y[idx] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[slot]);
#else
        y[idx] = ggml_gelu_quick_f32(x[idx]);
#endif
    }
}
- inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
- y[i] = GGML_CPU_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v))));
- }
- }
- // Sigmoid Linear Unit (SiLU) function
- inline static float ggml_silu_f32(float x) {
- return x/(1.0f + expf(-x));
- }
- inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
- float v = GGML_CPU_FP16_TO_FP32(x);
- return GGML_CPU_FP32_TO_FP16(v/(1.0f + expf(-v)));
- }
- #if __FINITE_MATH_ONLY__ // refuse to build in finite-math-only mode: e.g. the AVX512 ggml_v_expf in this header returns INFINITY on overflow
- #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
- #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
- #endif
- /* Below function was borrowed from the GitHub repository:
- https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
-
- exp_ps_sve: single-precision exp(src) on an SVE vector, for the lanes selected by pg.
- Outline, as read from the code below:
-   1. y = src * log2(e), split into n = floor(y) and b = (y - n) + 1
-   2. c = FEXPA applied to the bits of b shifted right by 17, then scaled by 2^n
-   3. z = b minus b with its low 17 bits cleared; the result is c * (1 + z*(ln2 + half_ln2_sq*z))
- NOTE(review): lanes that are inactive in pg are not given a defined exp() value here --
- callers presumably store the result under the same predicate; confirm at the call sites. */
- #if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
- inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
- // Constants
- const svfloat32_t log2_e = svdup_n_f32(1.4426950409f); // log2(e)
- const svfloat32_t ln2 = svdup_n_f32(0.6931473921f); // linear coefficient of the correction polynomial (close to ln 2)
- const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f); // quadratic coefficient of the correction polynomial (close to ln(2)^2 / 2)
- const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1)); // mask that clears the low 17 bits
- const svfloat32_t one = svdup_n_f32(1.0f);
- const svfloat32_t inactive1 = svdup_n_f32(0.0f); // "inactive" operand of the float rounding below
- const svint32_t inactive2 = svdup_n_s32(0); // "inactive" operand of the float -> int conversion below
- // Algorithm starts here
- svfloat32_t t0 = svmul_f32_m(pg, src, log2_e); // y = x * log2(e)
- svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0); // floor(y): round toward minus infinity, still a float
- svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1); // n = floor(y) as an integer
- t1 = svsub_f32_m(pg, t0, t1); // a = y - floor(y)
- t1 = svadd_f32_m(pg, t1, one); // b = a + 1
- svuint32_t t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17); // v = b >> 17 (u32)
- svfloat32_t t4 = svexpa_f32(t3); // c = fexpa(v)
- t4 = svscale_f32_m(pg, t4, t2); // fexpa(v) * 2^(n)
- // b with its low 17 bits cleared (original note: and_(t2.d, t1.d, not_mask17.d))
- svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17));
- t5 = svsub_f32_m(pg, t1, t5); // z = b minus its truncated value, i.e. the part of b dropped by the >> 17 above
- t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq); // ln2 + half_ln2_sq * z
- t0 = svmla_f32_m(pg, one, t5, t0); // 1 + (ln2 * z) + (half_ln2_sq * z * z)
- t0 = svmul_f32_m(pg, t0, t4); // Final result
- return t0;
- }
- #endif
- #if defined(__ARM_NEON) && defined(__aarch64__)
- // ggml_v_expf (NEON): expf() of four floats at once; adapted from arm limited optimized routine
- // the maximum error is 1.45358 plus 0.5 ulps
- // numbers above 88.38 will flush to infinity
- // numbers beneath -103.97 will flush to zero
- inline static float32x4_t ggml_v_expf(float32x4_t x) {
- const float32x4_t r = vdupq_n_f32(0x1.8p23f); // 1.5 * 2^23; NOTE(review): presumably the usual add-a-large-constant rounding trick
- const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); // z = r + x * 0x1.715476p+0 (about log2(e))
- const float32x4_t n = vsubq_f32(z, r); // n = z - r
- const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
- vdupq_n_f32(0x1.7f7d1cp-20f)); // b = x - n*0x1.62e4p-1 - n*0x1.7f7d1cp-20 (the two constants are presumably ln 2 split into high and low parts)
- const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); // bit pattern of z shifted left by 23 bits
- const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); // e added to the bits of 1.0f, read back as float (presumably 2^n)
- const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); // lanes with |n| > 126
- const float32x4_t u = vmulq_f32(b, b); // u = b*b
- const float32x4_t j = vfmaq_f32(
- vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
- vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
- vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); // j = c1*b + ((c2 + c3*b) + (c4 + c5*b)*u)*u, a polynomial in b with no constant term
- if (!vpaddd_u64(vreinterpretq_u64_u32(c))) // fast path: no lane has |n| > 126
- return vfmaq_f32(k, j, k); // k + j*k
- const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); // 0x82000000 in lanes with n <= 0, else 0
- const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); // scale factor built from bits 0x7f000000 + d
- const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); // scale factor built from bits e - d
- return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), // per lane: |n| > 192 -> s1*s1
- vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); // else |n| > 126 -> (s2 + s2*j)*s1, else k + k*j
- }
// SiLU on a NEON vector in single precision: x / (1 + exp(-x)).
inline static float32x4_t ggml_v_silu(float32x4_t x) {
    // exp(-x), where -x is formed as 0 - x
    const float32x4_t e = ggml_v_expf(vsubq_f32(vdupq_n_f32(0.0f), x));
    // x / (1 + exp(-x))
    return vdivq_f32(x, vaddq_f32(vdupq_n_f32(1.0f), e));
}
- #elif defined(__AVX512F__) && defined(__AVX512DQ__)
// Vector expf for AVX-512 (16 floats), adapted from the Arm Limited optimized routine.
// Accuracy: maximum error is 1.45358 plus 0.5 ulps.
// Range:    inputs above 88.38 flush to infinity, inputs below -103.97 flush to zero.
inline static __m512 ggml_v_expf(__m512 x) {
    const __m512 shift   = _mm512_set1_ps(0x1.8p23f);        // 1.5 * 2^23
    const __m512 inv_ln2 = _mm512_set1_ps(0x1.715476p+0f);   // ~1/ln(2)
    const __m512 ln2_hi  = _mm512_set1_ps(0x1.62e4p-1f);     // high part of ln(2)
    const __m512 ln2_lo  = _mm512_set1_ps(0x1.7f7d1cp-20f);  // low part of ln(2)

    // n = x/ln2 rounded to an integer by adding and then removing the shift constant
    const __m512 shifted = _mm512_fmadd_ps(x, inv_ln2, shift);
    const __m512 n       = _mm512_sub_ps(shifted, shift);

    // reduced argument: rem = x - n*ln2_hi - n*ln2_lo
    const __m512 rem = _mm512_fnmadd_ps(n, ln2_lo, _mm512_fnmadd_ps(n, ln2_hi, x));

    // lanes with |n| > 192 get a saturated result instead of the scaled polynomial
    const __mmask16 out_of_range =
        _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);

    // polynomial in rem (including the leading 1): (p_hi*rem2 + p_mid)*rem2 + p_lo
    const __m512 rem2  = _mm512_mul_ps(rem, rem);
    const __m512 p_hi  = _mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), rem, _mm512_set1_ps(0x1.573e2ep-5f));
    const __m512 p_mid = _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), rem, _mm512_set1_ps(0x1.fffdb6p-2f));
    const __m512 p_lo  = _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), rem, _mm512_set1_ps(1.0F));
    const __m512 poly  = _mm512_fmadd_ps(_mm512_fmadd_ps(p_hi, rem2, p_mid), rem2, p_lo);

    // apply the power-of-two scaling by n
    const __m512 scaled = _mm512_scalef_ps(poly, n);
    if (_mm512_kortestz(out_of_range, out_of_range)) {
        return scaled;
    }

    // out-of-range lanes: zero where n <= 0, infinity otherwise
    const __m512    zero         = _mm512_setzero_ps();
    const __mmask16 non_positive = _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ);
    const __m512    saturated    = _mm512_mask_blend_ps(non_positive, _mm512_set1_ps(INFINITY), zero);
    return _mm512_mask_blend_ps(out_of_range, scaled, saturated);
}
- // computes silu x/(1+exp(-x)) in single precision vector
- inline static __m512 ggml_v_silu(__m512 x) {
- const __m512 one = _mm512_set1_ps(1);
- const __m512 zero = _mm512_setzero_ps();
- const __m512 neg_x = _mm512_sub_ps(zero, x);
- const __m512 exp_neg_x = ggml_v_expf(neg_x);
- const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
- return _mm512_div_ps(x, one_plus_exp_neg_x);
- }
- #elif defined(__AVX2__) && defined(__FMA__)
// Vector expf for AVX2 + FMA (8 floats), adapted from the Arm Limited optimized routine.
// Accuracy: maximum error is 1.45358 plus 0.5 ulps.
// Range:    inputs above 88.38 flush to infinity, inputs below -103.97 flush to zero.
inline static __m256 ggml_v_expf(__m256 x) {
    const __m256 shift    = _mm256_set1_ps(0x1.8p23f);        // 1.5 * 2^23
    const __m256 inv_ln2  = _mm256_set1_ps(0x1.715476p+0f);   // ~1/ln(2)
    const __m256 ln2_hi   = _mm256_set1_ps(0x1.62e4p-1f);     // high part of ln(2)
    const __m256 ln2_lo   = _mm256_set1_ps(0x1.7f7d1cp-20f);  // low part of ln(2)
    const __m256 sign_bit = _mm256_set1_ps(-0.f);

    // n = x/ln2 rounded to an integer by adding and then removing the shift constant
    const __m256 shifted = _mm256_fmadd_ps(x, inv_ln2, shift);
    const __m256 n       = _mm256_sub_ps(shifted, shift);
    const __m256 abs_n   = _mm256_andnot_ps(sign_bit, n);

    // reduced argument: rem = x - n*ln2_hi - n*ln2_lo
    const __m256 rem = _mm256_fnmadd_ps(n, ln2_lo, _mm256_fnmadd_ps(n, ln2_hi, x));

    // move the low bits of `shifted` into the exponent field and add them to the bits of 1.0f
    const __m256i exp_bits = _mm256_slli_epi32(_mm256_castps_si256(shifted), 23);
    const __m256  scale    = _mm256_castsi256_ps(
        _mm256_add_epi32(exp_bits, _mm256_castps_si256(_mm256_set1_ps(1))));

    // lanes with |n| > 126 cannot use the fast path below
    const __m256 special = _mm256_cmp_ps(abs_n, _mm256_set1_ps(126), _CMP_GT_OQ);

    // polynomial in rem (without the leading 1): (p_hi*rem2 + p_mid)*rem2 + c1*rem
    const __m256 rem2  = _mm256_mul_ps(rem, rem);
    const __m256 p_hi  = _mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), rem, _mm256_set1_ps(0x1.573e2ep-5f));
    const __m256 p_mid = _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), rem, _mm256_set1_ps(0x1.fffdb6p-2f));
    const __m256 p_lin = _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), rem);
    const __m256 poly  = _mm256_fmadd_ps(_mm256_fmadd_ps(p_hi, rem2, p_mid), rem2, p_lin);

    // fast path: no lane was flagged, result is poly*scale + scale
    if (!_mm256_movemask_ps(special)) {
        return _mm256_fmadd_ps(poly, scale, scale);
    }

    // slow path: bias is 0x82000000 in lanes with n <= 0 and 0 elsewhere;
    // the scale is applied as two separate factors s1 and s2
    const __m256i bias = _mm256_and_si256(
        _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
        _mm256_set1_epi32(0x82000000u));
    const __m256 s1 = _mm256_castsi256_ps(_mm256_add_epi32(bias, _mm256_set1_epi32(0x7f000000u)));
    const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(exp_bits, bias));

    const __m256 extreme   = _mm256_cmp_ps(abs_n, _mm256_set1_ps(192), _CMP_GT_OQ);
    const __m256 saturated = _mm256_mul_ps(s1, s1);                             // used where |n| > 192
    const __m256 rescaled  = _mm256_mul_ps(_mm256_fmadd_ps(s2, poly, s2), s1);  // used where 126 < |n| <= 192
    const __m256 in_range  = _mm256_fmadd_ps(scale, poly, scale);               // used in all remaining lanes

    // two-level bitwise select: first on `special`, then on `extreme`
    const __m256 non_extreme = _mm256_or_ps(_mm256_and_ps(special, rescaled),
                                            _mm256_andnot_ps(special, in_range));
    return _mm256_or_ps(_mm256_and_ps(extreme, saturated),
                        _mm256_andnot_ps(extreme, non_extreme));
}
- // computes silu x/(1+exp(-x)) in single precision vector
- inline static __m256 ggml_v_silu(__m256 x) {
- const __m256 one = _mm256_set1_ps(1);
- const __m256 zero = _mm256_setzero_ps();
- const __m256 neg_x = _mm256_sub_ps(zero, x);
- const __m256 exp_neg_x = ggml_v_expf(neg_x);
- const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
- return _mm256_div_ps(x, one_plus_exp_neg_x);
- }
- #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
// MADD128(x, y, z)  = x*y + z
// NMADD128(x, y, z) = z - x*y
// Fused instructions are used when FMA is available, otherwise a separate multiply and add/subtract.
#if defined(__FMA__)
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
#else
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
#endif
// Vector expf for SSE2 (4 floats), adapted from the Arm Limited optimized routine.
// Accuracy: maximum error is 1.45358 plus 0.5 ulps.
// Range:    inputs above 88.38 flush to infinity, inputs below -103.97 flush to zero.
inline static __m128 ggml_v_expf(__m128 x) {
    const __m128 shift    = _mm_set1_ps(0x1.8p23f);        // 1.5 * 2^23
    const __m128 inv_ln2  = _mm_set1_ps(0x1.715476p+0f);   // ~1/ln(2)
    const __m128 ln2_hi   = _mm_set1_ps(0x1.62e4p-1f);     // high part of ln(2)
    const __m128 ln2_lo   = _mm_set1_ps(0x1.7f7d1cp-20f);  // low part of ln(2)
    const __m128 sign_bit = _mm_set1_ps(-0.f);

    // n = x/ln2 rounded to an integer by adding and then removing the shift constant
    const __m128 shifted = MADD128(x, inv_ln2, shift);
    const __m128 n       = _mm_sub_ps(shifted, shift);
    const __m128 abs_n   = _mm_andnot_ps(sign_bit, n);

    // reduced argument: rem = x - n*ln2_hi - n*ln2_lo
    const __m128 rem = NMADD128(n, ln2_lo, NMADD128(n, ln2_hi, x));

    // move the low bits of `shifted` into the exponent field and add them to the bits of 1.0f
    const __m128i exp_bits = _mm_slli_epi32(_mm_castps_si128(shifted), 23);
    const __m128  scale    = _mm_castsi128_ps(_mm_add_epi32(exp_bits, _mm_castps_si128(_mm_set1_ps(1))));

    // lanes with |n| > 126 cannot use the fast path below
    const __m128 special = _mm_cmpgt_ps(abs_n, _mm_set1_ps(126));

    // polynomial in rem (without the leading 1): (p_hi*rem2 + p_mid)*rem2 + c1*rem
    const __m128 rem2  = _mm_mul_ps(rem, rem);
    const __m128 p_hi  = MADD128(_mm_set1_ps(0x1.0e4020p-7f), rem, _mm_set1_ps(0x1.573e2ep-5f));
    const __m128 p_mid = MADD128(_mm_set1_ps(0x1.555e66p-3f), rem, _mm_set1_ps(0x1.fffdb6p-2f));
    const __m128 p_lin = _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), rem);
    const __m128 poly  = MADD128(MADD128(p_hi, rem2, p_mid), rem2, p_lin);

    // fast path: no lane was flagged, result is poly*scale + scale
    if (!_mm_movemask_epi8(_mm_castps_si128(special))) {
        return MADD128(poly, scale, scale);
    }

    // slow path: bias is 0x82000000 in lanes with n <= 0 and 0 elsewhere;
    // the scale is applied as two separate factors s1 and s2
    const __m128i bias = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
                                       _mm_set1_epi32(0x82000000u));
    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(bias, _mm_set1_epi32(0x7f000000u)));
    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(exp_bits, bias));

    const __m128 extreme   = _mm_cmpgt_ps(abs_n, _mm_set1_ps(192));
    const __m128 saturated = _mm_mul_ps(s1, s1);                    // used where |n| > 192
    const __m128 rescaled  = _mm_mul_ps(MADD128(s2, poly, s2), s1); // used where 126 < |n| <= 192
    const __m128 in_range  = MADD128(scale, poly, scale);           // used in all remaining lanes

    // two-level bitwise select: first on `special`, then on `extreme`
    const __m128 non_extreme = _mm_or_ps(_mm_and_ps(special, rescaled),
                                         _mm_andnot_ps(special, in_range));
    return _mm_or_ps(_mm_and_ps(extreme, saturated),
                     _mm_andnot_ps(extreme, non_extreme));
}
- // computes silu x/(1+exp(-x)) in single precision vector
- inline static __m128 ggml_v_silu(__m128 x) {
- const __m128 one = _mm_set1_ps(1);
- const __m128 zero = _mm_setzero_ps();
- const __m128 neg_x = _mm_sub_ps(zero, x);
- const __m128 exp_neg_x = ggml_v_expf(neg_x);
- const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
- return _mm_div_ps(x, one_plus_exp_neg_x);
- }
- #endif // __ARM_NEON / __AVX2__ / __SSE2__
- inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
- for (int i = 0; i < n; ++i) {
- y[i] = ggml_silu_f16(x[i]);
- }
- }
- inline static float ggml_silu_backward_f32(float x, float dy) {
- const float s = 1.0f/(1.0f + expf(-x));
- return dy*s*(1.0f + x*(1.0f - s));
- }
- inline static ggml_fp16_t ggml_silu_backward_f16(ggml_fp16_t x, ggml_fp16_t dy) {
- const float v = GGML_CPU_FP16_TO_FP32(x);
- const float s = 1.0f/(1.0f + expf(-v));
- return GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(dy)*s*(1.0f + v*(1.0f - s)));
- }
// Element-wise SiLU backward pass: dx[k] = ggml_silu_backward_f32(x[k], dy[k]) for k in [0, n).
inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
    int k = 0;
    while (k < n) {
        dx[k] = ggml_silu_backward_f32(x[k], dy[k]);
        ++k;
    }
}
- inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) {
- for (int i = 0; i < n; ++i) {
- dx[i] = ggml_silu_backward_f16(x[i], dy[i]);
- }
- }
// ReGLU: y[k] = x[k] * g[k] when x[k] is strictly positive, otherwise 0, for k in [0, n).
inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
    for (int k = 0; k < n; ++k) {
        const float xv = x[k];
        if (xv > 0.f) {
            y[k] = xv * g[k];
        } else {
            // also taken when xv is NaN, since the comparison is then false
            y[k] = 0.f;
        }
    }
}
- inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
- for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
- y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f);
- }
- }
#ifdef GGML_GELU_FP16
// GeGLU using the fp16 GELU lookup table: y[k] = gelu(x[k]) * g[k] for k in [0, n).
// Inputs at or below -10 produce 0, inputs at or above 10 are passed through as x[k] * g[k];
// everything else is rounded to fp16 and its bit pattern indexes ggml_table_gelu_f16.
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
    for (int k = 0; k < n; ++k) {
        if (x[k] <= -10.0f) {
            y[k] = 0.0f;
            continue;
        }
        if (x[k] >= 10.0f) {
            y[k] = x[k] * g[k];
            continue;
        }
        // reinterpret the fp16 value as a 16-bit table index
        uint16_t bits;
        ggml_fp16_t half = GGML_CPU_FP32_TO_FP16(x[k]);
        memcpy(&bits, &half, sizeof(uint16_t));
        y[k] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[bits]) * g[k];
    }
}
#else
// GeGLU computed directly: y[k] = ggml_gelu_f32(x[k]) * g[k] for k in [0, n).
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
    int k = 0;
    while (k < n) {
        y[k] = ggml_gelu_f32(x[k]) * g[k];
        ++k;
    }
}
#endif
- inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
- const uint16_t * i16 = (const uint16_t *) x;
- for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(g[i]);
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
- }
- }
- void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
- inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
- for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
- float w = GGML_CPU_FP16_TO_FP32(g[i]);
- y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
- }
- }
- inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
- for (int i = 0; i < n; ++i) {
- float xi = x[i];
- y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
- }
- }
- inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
- for (int i = 0; i < n; ++i) {
- float xi = GGML_CPU_FP16_TO_FP32(x[i]);
- float gi = GGML_CPU_FP16_TO_FP32(g[i]);
- y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
- }
- }
#ifdef GGML_GELU_QUICK_FP16
// GeGLU (quick GELU variant) using the fp16 lookup table: each x[k] is rounded to fp16,
// its bit pattern indexes ggml_table_gelu_quick_f16, and the entry is multiplied by g[k].
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    for (int k = 0; k < n; ++k) {
        // reinterpret the fp16 value as a 16-bit table index
        uint16_t bits;
        ggml_fp16_t half = GGML_CPU_FP32_TO_FP16(x[k]);
        memcpy(&bits, &half, sizeof(uint16_t));
        y[k] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[bits]) * g[k];
    }
}
#else
// GeGLU (quick GELU variant) computed directly:
// y[k] = ggml_gelu_quick_f32(x[k]) * g[k] for k in [0, n).
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    int k = 0;
    while (k < n) {
        y[k] = ggml_gelu_quick_f32(x[k]) * g[k];
        ++k;
    }
}
#endif
- inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
- const uint16_t * i16 = (const uint16_t *) x;
- for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(g[i]);
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
- }
- }
- inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
- #ifndef GGML_USE_ACCELERATE
- ggml_float sum = 0.0;
- for (int i = 0; i < n; ++i) {
- sum += (ggml_float)x[i];
- }
- *s = (float)sum;
- #else
- vDSP_sve(x, 1, s, n);
- #endif
- }
- inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
- ggml_float sum = 0.0;
- for (int i = 0; i < n; ++i) {
- sum += (ggml_float)x[i];
- }
- *s = sum;
- }
- inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
- float sum = 0.0f;
- for (int i = 0; i < n; ++i) {
- sum += GGML_CPU_FP16_TO_FP32(x[i]);
- }
- *s = sum;
- }
- inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
- float sum = 0.0f;
- for (int i = 0; i < n; ++i) {
- sum += GGML_BF16_TO_FP32(x[i]);
- }
- *s = sum;
- }
// Stores the maximum of the first n floats of x in *s.
// With GGML_USE_ACCELERATE the work is delegated to vDSP_maxv; otherwise a running
// maximum starting at -INFINITY is kept, so n <= 0 yields -INFINITY on that path.
inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
#ifdef GGML_USE_ACCELERATE
    vDSP_maxv(x, 1, s, n);
#else
    float best = -INFINITY;
    int k = 0;
    while (k < n) {
        best = MAX(best, x[k]);
        ++k;
    }
    *s = best;
#endif
}
// Writes the reciprocal of the value produced by ggml_vec_norm_f32(n, s, x) into *s.
inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
    ggml_vec_norm_f32(n, s, x);
    const float norm = *s;
    *s = 1.f/norm;
}
// Stores in *s the index of the largest of the first n floats of x.
// The index is updated whenever an element equals the running maximum, so among
// equal maxima the last one wins. For n <= 0 the stored index is 0.
inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
    float best     = -INFINITY;
    int   best_idx = 0;
    int   k        = 0;
    while (k < n) {
        best = MAX(best, x[k]);
        if (best == x[k]) {
            best_idx = k;
        }
        ++k;
    }
    *s = best_idx;
}
- #ifdef __cplusplus
- }
- #endif
|