|
@@ -657,9 +657,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
|
|
|
#define QK8_0 32
|
|
#define QK8_0 32
|
|
|
typedef struct {
|
|
typedef struct {
|
|
|
float d; // delta
|
|
float d; // delta
|
|
|
|
|
+ float s; // d * sum(qs[i])
|
|
|
int8_t qs[QK8_0]; // quants
|
|
int8_t qs[QK8_0]; // quants
|
|
|
} block_q8_0;
|
|
} block_q8_0;
|
|
|
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
|
|
|
|
|
|
+static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
|
|
|
|
|
|
|
|
|
|
|
|
// reference implementation for deterministic creation of model files
|
|
// reference implementation for deterministic creation of model files
|
|
@@ -1299,12 +1300,38 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
|
|
|
|
|
|
|
|
y[i].d = d;
|
|
y[i].d = d;
|
|
|
|
|
|
|
|
|
|
+ int sum = 0;
|
|
|
for (int l = 0; l < QK8_0; ++l) {
|
|
for (int l = 0; l < QK8_0; ++l) {
|
|
|
const float v = x[i*QK8_0 + l]*id;
|
|
const float v = x[i*QK8_0 + l]*id;
|
|
|
y[i].qs[l] = roundf(v);
|
|
y[i].qs[l] = roundf(v);
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ sum += y[i].qs[l];
|
|
|
|
|
+ }
|
|
|
|
|
+ y[i].s = d * sum;
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+#ifdef __AVX2__
|
|
|
|
|
+// There is no better way of doing this?
|
|
|
|
|
+// I guess not, AVX is not very good at horizontal sums.
|
|
|
|
|
+// The commented solution for a hotrizontal sum was suggested by @pubby as being slightly
|
|
|
|
|
+// faster than the solution below. As I don't have an AVX2 system handt right now to test,
|
|
|
|
|
+// keeping the original.
|
|
|
|
|
+// TODO: Please try and if it does make a differece, uncomment and remove the implementation below.
|
|
|
|
|
+//static inline float horizontal_sum(__m256i a) {
|
|
|
|
|
+// __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
|
|
|
|
|
+// __m256i sum = _mm256_add_epi32(a, b);
|
|
|
|
|
+// __m256i hi = _mm256_unpackhi_epi64(sum, sum);
|
|
|
|
|
+// sum = _mm256_add_epi32(sum, hi);
|
|
|
|
|
+// return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
|
|
|
|
|
+//}
|
|
|
|
|
+static inline float horizontal_sum(__m256i a) {
|
|
|
|
|
+ __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
|
|
|
|
|
+ __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
|
|
|
|
|
+ __m128i sum64 = _mm_add_epi32(hi64, sum128);
|
|
|
|
|
+ __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
|
|
|
|
|
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
|
|
|
}
|
|
}
|
|
|
|
|
+#endif
|
|
|
|
|
|
|
|
static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
|
|
static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
|
|
|
assert(k % QK8_0 == 0);
|
|
assert(k % QK8_0 == 0);
|
|
@@ -1332,6 +1359,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
|
|
|
|
|
|
y[i].d = d;
|
|
y[i].d = d;
|
|
|
|
|
|
|
|
|
|
+ int32x4_t accv = vdupq_n_s32(0);
|
|
|
|
|
+
|
|
|
for (int l = 0; l < 8; l++) {
|
|
for (int l = 0; l < 8; l++) {
|
|
|
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
|
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
|
|
const int32x4_t vi = vcvtnq_s32_f32(v);
|
|
const int32x4_t vi = vcvtnq_s32_f32(v);
|
|
@@ -1340,7 +1369,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
|
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
|
|
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
|
|
|
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
|
|
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
|
|
|
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
|
|
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
|
|
|
|
|
+
|
|
|
|
|
+ accv = vaddq_s32(accv, vi);
|
|
|
}
|
|
}
|
|
|
|
|
+ int32_t sum = vaddvq_s32(accv);
|
|
|
|
|
+ y[i].s = d * sum;
|
|
|
}
|
|
}
|
|
|
#elif defined(__AVX2__) || defined(__AVX__)
|
|
#elif defined(__AVX2__) || defined(__AVX__)
|
|
|
for (int i = 0; i < nb; i++) {
|
|
for (int i = 0; i < nb; i++) {
|
|
@@ -1388,6 +1421,10 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
|
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
|
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
|
|
|
|
|
|
|
#if defined(__AVX2__)
|
|
#if defined(__AVX2__)
|
|
|
|
|
+
|
|
|
|
|
+ // Compute the sum of the quants and set y[i].s
|
|
|
|
|
+ y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
|
|
|
|
|
+
|
|
|
// Convert int32 to int16
|
|
// Convert int32 to int16
|
|
|
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
|
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
|
|
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
|
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
|
@@ -1430,6 +1467,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
|
// scalar
|
|
// scalar
|
|
|
quantize_row_q8_0_reference(x, y, k);
|
|
quantize_row_q8_0_reference(x, y, k);
|
|
|
#endif
|
|
#endif
|
|
|
|
|
+#if defined __AVX__
|
|
|
|
|
+ // TODO: vectorize this
|
|
|
|
|
+ for (int i=0; i<nb; ++i) {
|
|
|
|
|
+ int sum = 0;
|
|
|
|
|
+ for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l];
|
|
|
|
|
+ y[i].s = y[i].d * sum;
|
|
|
|
|
+ }
|
|
|
|
|
+#endif
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
|
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
|
@@ -2372,14 +2417,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
|
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
|
|
|
|
|
|
|
|
|
+ float sum8 = 0;
|
|
|
|
|
+
|
|
|
for (int i = 0; i < nb; i += 2) {
|
|
for (int i = 0; i < nb; i += 2) {
|
|
|
const block_q4_0 * restrict x0 = &x[i + 0];
|
|
const block_q4_0 * restrict x0 = &x[i + 0];
|
|
|
const block_q4_0 * restrict x1 = &x[i + 1];
|
|
const block_q4_0 * restrict x1 = &x[i + 1];
|
|
|
const block_q8_0 * restrict y0 = &y[i + 0];
|
|
const block_q8_0 * restrict y0 = &y[i + 0];
|
|
|
const block_q8_0 * restrict y1 = &y[i + 1];
|
|
const block_q8_0 * restrict y1 = &y[i + 1];
|
|
|
|
|
|
|
|
|
|
+ sum8 += x0->d * y0->s + x1->d * y1->s;
|
|
|
|
|
+
|
|
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
|
|
- const int8x16_t s8b = vdupq_n_s8(0x8);
|
|
|
|
|
|
|
|
|
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
|
|
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
|
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
|
@@ -2390,12 +2438,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
|
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
|
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
|
|
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
|
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
|
|
|
|
|
|
|
- // sub 8
|
|
|
|
|
- const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
|
|
|
|
- const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
|
|
|
|
- const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
|
|
|
|
- const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
|
|
|
|
-
|
|
|
|
|
// load y
|
|
// load y
|
|
|
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
|
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
|
|
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
|
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
|
@@ -2410,21 +2452,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
|
|
|
|
|
|
#if defined(__ARM_FEATURE_DOTPROD)
|
|
#if defined(__ARM_FEATURE_DOTPROD)
|
|
|
// dot product into int32x4_t
|
|
// dot product into int32x4_t
|
|
|
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
|
|
|
|
|
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
|
|
|
|
|
|
|
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
|
|
|
|
|
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
|
|
|
|
|
|
|
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
|
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
|
|
#else
|
|
#else
|
|
|
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
|
|
|
|
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
|
|
|
|
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
|
|
|
|
|
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
|
|
|
|
|
|
|
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
|
|
|
|
|
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
|
|
|
|
|
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
|
|
|
|
|
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
|
|
|
|
|
|
|
|
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
|
|
|
|
|
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
|
|
|
|
|
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
|
|
|
|
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
|
|
|
|
|
|
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
|
|
|
|
|
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
|
|
|
|
|
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
|
|
|
|
|
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
|
|
|
|
|
|
|
|
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
|
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
|
|
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
|
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
|
@@ -2436,7 +2478,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
|
#endif
|
|
#endif
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
|
|
|
|
|
|
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
|
|
|
#elif defined(__AVX2__)
|
|
#elif defined(__AVX2__)
|
|
|
// Initialize accumulator with zeros
|
|
// Initialize accumulator with zeros
|
|
|
__m256 acc = _mm256_setzero_ps();
|
|
__m256 acc = _mm256_setzero_ps();
|
|
@@ -2569,12 +2611,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
|
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
|
|
|
|
|
|
|
|
|
+ float summs = 0;
|
|
|
|
|
+
|
|
|
for (int i = 0; i < nb; i += 2) {
|
|
for (int i = 0; i < nb; i += 2) {
|
|
|
const block_q4_1 * restrict x0 = &x[i + 0];
|
|
const block_q4_1 * restrict x0 = &x[i + 0];
|
|
|
const block_q4_1 * restrict x1 = &x[i + 1];
|
|
const block_q4_1 * restrict x1 = &x[i + 1];
|
|
|
const block_q8_0 * restrict y0 = &y[i + 0];
|
|
const block_q8_0 * restrict y0 = &y[i + 0];
|
|
|
const block_q8_0 * restrict y1 = &y[i + 1];
|
|
const block_q8_0 * restrict y1 = &y[i + 1];
|
|
|
|
|
|
|
|
|
|
+ summs += x0->m * y0->s + x1->m * y1->s;
|
|
|
|
|
+
|
|
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
|
|
|
|
|
|
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
|
@@ -2598,17 +2644,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
|
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
|
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
|
|
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
|
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
|
|
|
|
|
|
|
- const int16x8_t s0i = vaddq_s16(
|
|
|
|
|
- vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
|
|
|
|
|
- vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
|
|
|
|
|
-
|
|
|
|
|
- const int16x8_t s1i = vaddq_s16(
|
|
|
|
|
- vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
|
|
|
|
|
- vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
|
|
|
|
|
-
|
|
|
|
|
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
|
|
|
|
|
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
|
|
|
|
|
-
|
|
|
|
|
#if defined(__ARM_FEATURE_DOTPROD)
|
|
#if defined(__ARM_FEATURE_DOTPROD)
|
|
|
// dot product into int32x4_t
|
|
// dot product into int32x4_t
|
|
|
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
|
|
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
|
|
@@ -2637,24 +2672,26 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
|
#endif
|
|
#endif
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
|
|
|
|
|
|
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
|
|
|
#elif defined(__AVX2__)
|
|
#elif defined(__AVX2__)
|
|
|
// Initialize accumulator with zeros
|
|
// Initialize accumulator with zeros
|
|
|
__m256 acc = _mm256_setzero_ps();
|
|
__m256 acc = _mm256_setzero_ps();
|
|
|
|
|
|
|
|
|
|
+ float summs = 0;
|
|
|
|
|
+
|
|
|
// Main loop
|
|
// Main loop
|
|
|
for (int i = 0; i < nb; ++i) {
|
|
for (int i = 0; i < nb; ++i) {
|
|
|
const float * d0 = &x[i].d;
|
|
const float * d0 = &x[i].d;
|
|
|
const float * d1 = &y[i].d;
|
|
const float * d1 = &y[i].d;
|
|
|
- const float * m0 = &x[i].m;
|
|
|
|
|
|
|
+ //const float * m0 = &x[i].m;
|
|
|
|
|
+
|
|
|
|
|
+ summs += x[i].m * y[i].s;
|
|
|
|
|
|
|
|
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
|
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
|
|
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
|
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
|
|
- const __m256 m0v = _mm256_broadcast_ss( m0 );
|
|
|
|
|
|
|
|
|
|
// Compute combined scales
|
|
// Compute combined scales
|
|
|
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
|
|
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
|
|
|
- const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
|
|
|
|
|
|
|
|
|
|
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
|
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
|
|
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
|
|
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
|
|
@@ -2676,15 +2713,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
|
|
|
|
|
|
// Accumulate d0*d1*x*y
|
|
// Accumulate d0*d1*x*y
|
|
|
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
|
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
|
|
-
|
|
|
|
|
- // Compute sum of y values
|
|
|
|
|
- const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
|
|
|
|
- const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
|
|
|
|
- const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
|
|
|
|
|
- const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
|
|
|
|
|
-
|
|
|
|
|
- // Accumulate d1*m0*y
|
|
|
|
|
- acc = _mm256_fmadd_ps( d1m0, ysum, acc );
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Return horizontal sum of the acc vector
|
|
// Return horizontal sum of the acc vector
|
|
@@ -2693,7 +2721,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
|
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
|
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
|
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
|
|
|
|
|
|
|
- sumf = _mm_cvtss_f32( res );
|
|
|
|
|
|
|
+ sumf = _mm_cvtss_f32( res ) + summs;
|
|
|
#else
|
|
#else
|
|
|
// scalar
|
|
// scalar
|
|
|
for (int i = 0; i < nb; i++) {
|
|
for (int i = 0; i < nb; i++) {
|