|
|
@@ -1630,7 +1630,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6
|
|
|
// ===================== Helper functions
|
|
|
//
|
|
|
static inline int nearest_int(float fval) {
|
|
|
- assert(fval <= 4194303.f);
|
|
|
+ assert(fabsf(fval) <= 4194303.f);
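+ // The rounding trick below relies on 12582912.0f == 1.5 * 2^23: for |fval| <= 2^22 - 1
+ // (= 4194303), fval + 12582912.0f lands in [2^23, 2^24), so the low 23 mantissa bits
+ // hold round-to-nearest(fval) + 0x00400000, which the mask-and-subtract recovers.
+ // Checking fabsf() also rejects out-of-range negative inputs, not just positive ones.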
|
|
|
float val = fval + 12582912.f;
|
|
|
int i; memcpy(&i, &val, sizeof(int));
|
|
|
return (i & 0x007fffff) - 0x00400000;
|
|
|
@@ -3306,6 +3306,191 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
|
|
|
return nrow * row_size;
|
|
|
}
|
|
|
|
|
|
+// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
|
|
|
+
|
|
|
+void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) {
|
|
|
+ assert(k % QK_K == 0);
|
|
|
+ const int64_t nb = k / QK_K;
|
|
|
+
|
|
|
+ for (int64_t i = 0; i < nb; i++) {
|
|
|
+ float amax = 0.0f; // absolute max
|
|
|
+
|
|
|
+ for (int j = 0; j < QK_K; j++) {
|
|
|
+ const float v = x[j];
|
|
|
+ amax = MAX(amax, fabsf(v));
|
|
|
+ }
|
|
|
+
|
|
|
+ const float d = amax;
|
|
|
+ const float id = d ? 1.0f/d : 0.0f;
|
|
|
+
|
|
|
+ y[i].d = GGML_FP32_TO_FP16(d);
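+ // Packing note: with QK_K == 256, each block stores 48 qs bytes (5 values per byte),
+ // 4 qh bytes (4 values per byte) and an fp16 scale. Five mapped values v0..v4 (each
+ // 0, 1 or 2) are combined base-3 as q3 = (((v0*3 + v1)*3 + v2)*3 + v3)*3 + v4, then
+ // rescaled with the ceiling division below to ceil(q3 * 256 / 243), which lets the
+ // values be peeled back off with repeated (q * 3) >> 8 steps during dequantization.
+ // For example, v = {2, 1, 0, 2, 1} gives q3 = 196 and a stored byte of 207.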
|
|
|
+
|
|
|
+ // 5 elements per byte, packed 32 bytes at a time
|
|
|
+ for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) {
|
|
|
+ for (size_t m = 0; m < 32; ++m) {
|
|
|
+ uint8_t q = 0;
|
|
|
+ for (size_t n = 0; n < 5; ++n) {
|
|
|
+ int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2
|
|
|
+ q *= 3;
|
|
|
+ q += xi;
|
|
|
+ }
|
|
|
+ // ceiling division (243 == pow(3, 5))
|
|
|
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
|
|
|
+ y[i].qs[j + m] = q;
|
|
|
+ }
|
|
|
+ x += 5*32;
|
|
|
+ }
|
|
|
+ // remaining qs bytes, packed 16 at a time
|
|
|
+ for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) {
|
|
|
+ for (size_t m = 0; m < 16; ++m) {
|
|
|
+ uint8_t q = 0;
|
|
|
+ for (size_t n = 0; n < 5; ++n) {
|
|
|
+ int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2
|
|
|
+ q *= 3;
|
|
|
+ q += xi;
|
|
|
+ }
|
|
|
+ // ceiling division (243 == pow(3, 5))
|
|
|
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
|
|
|
+ y[i].qs[j + m] = q;
|
|
|
+ }
|
|
|
+ x += 5*16;
|
|
|
+ }
|
|
|
+ // 4 elements per byte
|
|
|
+ for (size_t j = 0; j < sizeof(y->qh); ++j) {
|
|
|
+ uint8_t q = 0;
|
|
|
+ for (size_t m = 0; m < 4; ++m) {
|
|
|
+ // -1, 0, 1 -> 0, 1, 2
|
|
|
+ int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1;
|
|
|
+ q *= 3;
|
|
|
+ q += xi;
|
|
|
+ }
|
|
|
+ // shift the first value to the most significant trit
|
|
|
+ q *= 3;
|
|
|
+ // ceiling division (243 == pow(3, 5))
|
|
|
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
|
|
|
+ y[i].qh[j] = q;
|
|
|
+ }
|
|
|
+ x += 4*sizeof(y->qh);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) {
|
|
|
+ assert(k % QK_K == 0);
|
|
|
+ const int64_t nb = k / QK_K;
|
|
|
+
|
|
|
+ for (int64_t i = 0; i < nb; i++) {
|
|
|
+ float amax = 0.0f; // absolute max
|
|
|
+
|
|
|
+ for (int j = 0; j < QK_K; j++) {
|
|
|
+ const float v = x[j];
|
|
|
+ amax = MAX(amax, fabsf(v));
|
|
|
+ }
|
|
|
+
|
|
|
+ const float d = amax;
|
|
|
+ const float id = d ? 1.0f/d : 0.0f;
|
|
|
+
|
|
|
+ y[i].d = GGML_FP32_TO_FP16(d);
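+ // Packing note: with QK_K == 256, each block stores 64 qs bytes, 4 values per byte as
+ // 2-bit fields where 0, 1, 2 encode -1, 0, +1. For example, the values {+1, -1, 0, +1}
+ // map to fields {2, 0, 1, 2} and pack into the byte 0x92.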
|
|
|
+
|
|
|
+ for (size_t j = 0; j < sizeof(y->qs); j += 32) {
|
|
|
+ for (size_t m = 0; m < 32; ++m) {
|
|
|
+ uint8_t q = 0;
|
|
|
+ for (size_t n = 0; n < 4; ++n) {
|
|
|
+ // -1, 0, 1 -> 0, 1, 2
|
|
|
+ int xi = lroundf(x[m + n*32] * id) + 1;
|
|
|
+ q += (xi & 3) << (2*n);
|
|
|
+ }
|
|
|
+ y[i].qs[j + m] = q;
|
|
|
+ }
|
|
|
+ x += 4*32;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) {
|
|
|
+ assert(k % QK_K == 0);
|
|
|
+ block_tq1_0 * restrict y = vy;
|
|
|
+ quantize_row_tq1_0_ref(x, y, k);
|
|
|
+}
|
|
|
+
|
|
|
+void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) {
|
|
|
+ assert(k % QK_K == 0);
|
|
|
+ block_tq2_0 * restrict y = vy;
|
|
|
+ quantize_row_tq2_0_ref(x, y, k);
|
|
|
+}
|
|
|
+
|
|
|
+size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
|
|
+ (void)quant_weights; // not used
|
|
|
+ const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row);
|
|
|
+ quantize_row_tq1_0(src, dst, (int64_t)nrow*n_per_row);
|
|
|
+ return nrow * row_size;
|
|
|
+}
|
|
|
+
|
|
|
+size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
|
|
+ (void)quant_weights; // not used
|
|
|
+ const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row);
|
|
|
+ quantize_row_tq2_0(src, dst, (int64_t)nrow*n_per_row);
|
|
|
+ return nrow * row_size;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) {
|
|
|
+ assert(k % QK_K == 0);
|
|
|
+ const int64_t nb = k / QK_K;
|
|
|
+
|
|
|
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
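+ // Each qs byte stores ceil(q3 * 256 / 243), i.e. a 0.8 fixed-point approximation of
+ // q3 / 243; multiplying by pow3[n] shifts the base-3 expansion left by n digits, and
+ // ((uint16_t) q * 3) >> 8 then reads out the next trit (0, 1 or 2).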
|
|
|
+
|
|
|
+ for (int64_t i = 0; i < nb; ++i) {
|
|
|
+
|
|
|
+ const float d = GGML_FP16_TO_FP32(x[i].d);
|
|
|
+
|
|
|
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
|
|
|
+ for (size_t n = 0; n < 5; ++n) {
|
|
|
+ for (size_t m = 0; m < 32; ++m) {
|
|
|
+ uint8_t q = x[i].qs[j + m] * pow3[n];
|
|
|
+ int16_t xi = ((uint16_t) q * 3) >> 8;
|
|
|
+ *y++ = (float) (xi - 1) * d;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
|
|
|
+ for (size_t n = 0; n < 5; ++n) {
|
|
|
+ for (size_t m = 0; m < 16; ++m) {
|
|
|
+ uint8_t q = x[i].qs[j + m] * pow3[n];
|
|
|
+ int16_t xi = ((uint16_t) q * 3) >> 8;
|
|
|
+ *y++ = (float) (xi - 1) * d;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for (size_t n = 0; n < 4; ++n) {
|
|
|
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
|
|
|
+ uint8_t q = x[i].qh[j] * pow3[n];
|
|
|
+ int16_t xi = ((uint16_t) q * 3) >> 8;
|
|
|
+ *y++ = (float) (xi - 1) * d;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, int64_t k) {
|
|
|
+ assert(k % QK_K == 0);
|
|
|
+ const int64_t nb = k / QK_K;
|
|
|
+
|
|
|
+ for (int64_t i = 0; i < nb; ++i) {
|
|
|
+
|
|
|
+ const float d = GGML_FP16_TO_FP32(x[i].d);
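+ // Each qs byte holds 4 values as 2-bit fields, and (q - 1) maps the stored 0, 1, 2
+ // back to -1, 0, +1; for example the byte 0x92 expands to {+1, -1, 0, +1} * d.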
|
|
|
+
|
|
|
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
|
|
|
+ for (size_t l = 0; l < 4; ++l) {
|
|
|
+ for (size_t m = 0; m < 32; ++m) {
|
|
|
+ int8_t q = (x[i].qs[j + m] >> (l*2)) & 3;
|
|
|
+ *y++ = (float) (q - 1) * d;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// ====================== "True" 2-bit (de)-quantization
|
|
|
|
|
|
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
|
|
|
@@ -5470,6 +5655,501 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
|
*s = sumf;
|
|
|
}
|
|
|
|
|
|
+void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
|
|
+ assert(nrc == 1);
|
|
|
+ UNUSED(nrc);
|
|
|
+ UNUSED(bx);
|
|
|
+ UNUSED(by);
|
|
|
+ UNUSED(bs);
|
|
|
+
|
|
|
+ const block_tq1_0 * restrict x = vx;
|
|
|
+ const block_q8_K * restrict y = vy;
|
|
|
+
|
|
|
+ const int nb = n / QK_K;
|
|
|
+
|
|
|
+#if defined(__ARM_NEON)
|
|
|
+ float sumf = 0.0f;
|
|
|
+
|
|
|
+ uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
|
|
|
+
|
|
|
+ const uint8x16_t shift = vld1q_u8(k_shift);
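+ // The shift vector scales the 4 broadcast qh bytes by 3^n per group of 4 lanes, so
+ // lane 4*n + j ends up holding qh[j] * 3^n with trit n moved toward the top of the
+ // byte (used for the 4-element tail block below).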
|
|
|
+
|
|
|
+ for (int i = 0; i < nb; ++i) {
|
|
|
+#if defined(__ARM_FEATURE_DOTPROD)
|
|
|
+ int32x4_t sumi0 = vdupq_n_s32(0);
|
|
|
+ int32x4_t sumi1 = vdupq_n_s32(0);
|
|
|
+#else
|
|
|
+ int16x8_t sumi0 = vdupq_n_s16(0);
|
|
|
+ int16x8_t sumi1 = vdupq_n_s16(0);
|
|
|
+#endif
|
|
|
+
|
|
|
+ // first 32 qs bytes, 5 elements per byte
|
|
|
+ {
|
|
|
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
|
|
|
+ uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
|
|
|
+ uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
|
|
|
+ uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
|
|
|
+ uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
|
|
|
+ uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
|
|
|
+ uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
|
|
|
+ uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
|
|
|
+ uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
|
|
|
+ uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
|
|
|
+
|
|
|
+ // multiply by 3 and keep the 2 bits above 8 bits
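+ // (q + (q >> 1)) >> 1 equals (3*q) >> 2 exactly, so the halving add followed by the
+ // extra >> 6 yields (3*q) >> 8 without overflowing 8 bits: the trit 0, 1 or 2 that
+ // sits just above the byte boundary of 3*q.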
|
|
|
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
|
|
|
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
|
|
|
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
|
|
|
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
|
|
|
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
|
|
|
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
|
|
|
+ int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
|
|
|
+ int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
|
|
|
+ int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
|
|
|
+ int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
|
|
|
+
|
|
|
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
|
|
|
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
|
|
|
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
|
|
|
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
|
|
|
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
|
|
|
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
|
|
|
+ const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
|
|
|
+ const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
|
|
|
+ const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
|
|
|
+ const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
|
|
|
+
|
|
|
+#if defined(__ARM_FEATURE_DOTPROD)
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx6, qy6);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx7, qy7);
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx8, qy8);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx9, qy9);
|
|
|
+#else
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
|
|
|
+#endif
|
|
|
+ }
|
|
|
+
|
|
|
+ // last 16 qs bytes (5 elements per byte), along with the 4 qh bytes (4 elements per byte)
|
|
|
+ {
|
|
|
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
|
|
|
+ uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
|
|
|
+ uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
|
|
|
+ uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
|
|
|
+ uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
|
|
|
+ uint32_t qh;
|
|
|
+ memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
|
|
|
+ uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
|
|
|
+ qx5 = vmulq_u8(qx5, shift);
|
|
|
+
|
|
|
+ // multiply by 3 and keep the 2 bits above 8 bits
|
|
|
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
|
|
|
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
|
|
|
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
|
|
|
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
|
|
|
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
|
|
|
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
|
|
|
+
|
|
|
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
|
|
|
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
|
|
|
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
|
|
|
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
|
|
|
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
|
|
|
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
|
|
|
+
|
|
|
+#if defined(__ARM_FEATURE_DOTPROD)
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
|
|
|
+#else
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
|
|
|
+#endif
|
|
|
+ }
|
|
|
+
|
|
|
+ const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
|
|
|
+ const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
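+ // The x values were used as unsigned {0, 1, 2} instead of {-1, 0, +1}, which
+ // overcounts the dot product by the sum of all y values in the block; subtracting
+ // the q8_K bsums below applies that correction.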
|
|
|
+
|
|
|
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
|
|
+
|
|
|
+#if defined(__ARM_FEATURE_DOTPROD)
|
|
|
+ sumi0 = vaddq_s32(sumi0, sumi1);
|
|
|
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
|
|
|
+
|
|
|
+ sumf += d * (float) vaddvq_s32(sumi0);
|
|
|
+#else
|
|
|
+ sumi0 = vaddq_s16(sumi0, sumi1);
|
|
|
+ sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
|
|
|
+
|
|
|
+ sumf += d * (float) vaddlvq_s16(sumi0);
|
|
|
+#endif
|
|
|
+ }
|
|
|
+
|
|
|
+ *s = sumf;
|
|
|
+
|
|
|
+#elif defined(__AVX2__)
|
|
|
+ __m256 sumf = _mm256_setzero_ps();
|
|
|
+
|
|
|
+ for (int i = 0; i < nb; ++i) {
|
|
|
+ // 16-bit sums
|
|
|
+ __m256i sumi0 = _mm256_setzero_si256();
|
|
|
+ __m256i sumi1 = _mm256_setzero_si256();
|
|
|
+ __m256i sumi2 = _mm256_setzero_si256();
|
|
|
+
|
|
|
+ // first 32 bytes of 5 elements
|
|
|
+ {
|
|
|
+ __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
|
|
|
+ // 8-bit multiplies with shifts, masks and adds
|
|
|
+ __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
|
|
|
+ __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
|
|
|
+ __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
|
|
|
+ __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9
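+ // ((x << 3) masked to stay within each byte) + x computes 9*x (mod 256) per byte;
+ // _mm256_set1_epi8(-8) is the 0xF8 mask that drops the bits shifted in from the
+ // neighbouring byte by the 16-bit shift.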
|
|
|
+
|
|
|
+ // TODO: could _mm256_mulhi_epu16 be faster here even though it operates on 16 bits?
|
|
|
+
|
|
|
+ // Cancel the +1 from avg so that it behaves like a halving add
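+ // _mm256_avg_epu8(a, b) computes (a + b + 1) >> 1; after the saturating -1 below,
+ // avg(q-1, avg(q-1, 0)) equals (3*q) >> 2 exactly (same as the NEON halving-add
+ // path), and the later >> 6 plus mask leaves (3*q) >> 8, the trit value 0, 1 or 2.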
|
|
|
+ qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
|
|
|
+ qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
|
|
|
+ qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
|
|
|
+ qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
|
|
|
+ qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
|
|
|
+ // Multiply by 3 and get the top 2 bits
|
|
|
+ qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
|
|
|
+ qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
|
|
|
+ qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
|
|
|
+ qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
|
|
|
+ qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
|
|
|
+ qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
|
|
|
+ qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
|
|
|
+ qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
|
|
|
+ qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
|
|
|
+ qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
|
|
|
+
|
|
|
+ const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0));
|
|
|
+ const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32));
|
|
|
+ const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64));
|
|
|
+ const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96));
|
|
|
+ const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
|
|
|
+
|
|
|
+ qx0 = _mm256_maddubs_epi16(qx0, qy0);
|
|
|
+ qx1 = _mm256_maddubs_epi16(qx1, qy1);
|
|
|
+ qx2 = _mm256_maddubs_epi16(qx2, qy2);
|
|
|
+ qx3 = _mm256_maddubs_epi16(qx3, qy3);
|
|
|
+ qx4 = _mm256_maddubs_epi16(qx4, qy4);
|
|
|
+
|
|
|
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
|
|
|
+ sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
|
|
|
+ sumi2 = _mm256_add_epi16(sumi2, qx4);
|
|
|
+ }
|
|
|
+
|
|
|
+ // last 16 qs bytes (5 elements per byte), along with the 4 qh bytes (4 elements per byte)
|
|
|
+ {
|
|
|
+ __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
|
|
|
+ uint32_t qh;
|
|
|
+ memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
|
|
|
+ __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
|
|
|
+ __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
|
|
|
+ __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
|
|
|
+ __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
|
|
|
+ __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
|
|
|
+ __m256i qx01 = MM256_SET_M128I(qx1, qx0);
|
|
|
+ __m256i qx23 = MM256_SET_M128I(qx3, qx2);
|
|
|
+
|
|
|
+ // avx2 does not have 8-bit multiplies, so 16-bit it is.
|
|
|
+ qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
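+ // _mm256_set_epi16 lists lanes from high to low, so the low four 16-bit lanes are
+ // multiplied by 1 and the high four by 27: lane 4*n + j becomes qh[j] * 3^n, the
+ // same per-trit scaling the NEON path gets from k_shift.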
|
|
|
+ qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
|
|
|
+ __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));
|
|
|
+
|
|
|
+ __m256i qx45 = MM256_SET_M128I(qx5, qx4);
|
|
|
+
|
|
|
+ // Cancel the +1 from avg so that it behaves like a halving add
|
|
|
+ qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
|
|
|
+ qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
|
|
|
+ qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
|
|
|
+ // Multiply by 3 and get the top 2 bits
|
|
|
+ qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
|
|
|
+ qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
|
|
|
+ qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
|
|
|
+ qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
|
|
|
+ qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
|
|
|
+ qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
|
|
|
+
|
|
|
+ const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
|
|
|
+ const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
|
|
|
+ const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
|
|
|
+
|
|
|
+ qx01 = _mm256_maddubs_epi16(qx01, qy01);
|
|
|
+ qx23 = _mm256_maddubs_epi16(qx23, qy23);
|
|
|
+ qx45 = _mm256_maddubs_epi16(qx45, qy45);
|
|
|
+
|
|
|
+ sumi0 = _mm256_add_epi16(sumi0, qx01);
|
|
|
+ sumi1 = _mm256_add_epi16(sumi1, qx23);
|
|
|
+ sumi2 = _mm256_add_epi16(sumi2, qx45);
|
|
|
+ }
|
|
|
+
|
|
|
+ const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
|
|
|
+ const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
|
|
|
+
|
|
|
+ sumi0 = _mm256_sub_epi16(sumi0, ysum);
|
|
|
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
|
|
|
+ sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
|
|
|
+
|
|
|
+ sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
|
|
|
+ }
|
|
|
+
|
|
|
+ *s = hsum_float_8(sumf);
|
|
|
+
|
|
|
+#else
|
|
|
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
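+ // Same trit extraction as in dequantize_row_tq1_0; the q8_K values are indexed with
+ // j*5 (and sizeof(x->qs)*5 for the qh part) because every qs byte covers 5 of them.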
|
|
|
+
|
|
|
+ float sumf = 0.0f;
|
|
|
+
|
|
|
+ for (int i = 0; i < nb; ++i) {
|
|
|
+ int sum = 0;
|
|
|
+
|
|
|
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
|
|
|
+ for (size_t l = 0; l < 5; ++l) {
|
|
|
+ for (size_t m = 0; m < 32; ++m) {
|
|
|
+ uint8_t q = x[i].qs[j + m] * pow3[l];
|
|
|
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
|
|
|
+ sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
|
|
|
+ for (size_t l = 0; l < 5; ++l) {
|
|
|
+ for (size_t m = 0; m < 16; ++m) {
|
|
|
+ uint8_t q = x[i].qs[j + m] * pow3[l];
|
|
|
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
|
|
|
+ sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for (size_t l = 0; l < 4; ++l) {
|
|
|
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
|
|
|
+ uint8_t q = x[i].qh[j] * pow3[l];
|
|
|
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
|
|
|
+ sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d);
|
|
|
+ }
|
|
|
+
|
|
|
+ *s = sumf;
|
|
|
+#endif
|
|
|
+}
|
|
|
+
|
|
|
+void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
|
|
+ assert(nrc == 1);
|
|
|
+ UNUSED(nrc);
|
|
|
+ UNUSED(bx);
|
|
|
+ UNUSED(by);
|
|
|
+ UNUSED(bs);
|
|
|
+
|
|
|
+ const block_tq2_0 * restrict x = vx;
|
|
|
+ const block_q8_K * restrict y = vy;
|
|
|
+
|
|
|
+ const int nb = n / QK_K;
|
|
|
+
|
|
|
+#if defined(__ARM_NEON)
|
|
|
+ float sumf = 0.0f;
|
|
|
+
|
|
|
+ const uint8x16_t m3 = vdupq_n_u8(3);
|
|
|
+
|
|
|
+ for (int i = 0; i < nb; ++i) {
|
|
|
+#if defined(__ARM_FEATURE_DOTPROD)
|
|
|
+ int32x4_t sumi0 = vdupq_n_s32(0);
|
|
|
+ int32x4_t sumi1 = vdupq_n_s32(0);
|
|
|
+#else
|
|
|
+ int16x8_t sumi0 = vdupq_n_s16(0);
|
|
|
+ int16x8_t sumi1 = vdupq_n_s16(0);
|
|
|
+#endif
|
|
|
+
|
|
|
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
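+ // Each 32-byte chunk of qs expands to 128 values (4 per byte, 2 bits each), which is
+ // why the matching q8_K data is read from y[i].qs + j*4.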
|
|
|
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
|
|
|
+ uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
|
|
|
+ uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
|
|
|
+ uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
|
|
|
+ uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
|
|
|
+ uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
|
|
|
+ uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
|
|
|
+ uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
|
|
|
+
|
|
|
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
|
|
|
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
|
|
|
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
|
|
|
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
|
|
|
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
|
|
|
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
|
|
|
+ int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
|
|
|
+ int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
|
|
|
+
|
|
|
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
|
|
|
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
|
|
|
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
|
|
|
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
|
|
|
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
|
|
|
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
|
|
|
+ const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
|
|
|
+ const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
|
|
|
+
|
|
|
+#if defined(__ARM_FEATURE_DOTPROD)
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
|
|
|
+ sumi0 = vdotq_s32(sumi0, sqx6, qy6);
|
|
|
+ sumi1 = vdotq_s32(sumi1, sqx7, qy7);
|
|
|
+#else
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
|
|
|
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
|
|
|
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
|
|
|
+#endif
|
|
|
+ }
|
|
|
+
|
|
|
+ const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
|
|
|
+ const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
|
|
|
+
|
|
|
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
|
|
+
|
|
|
+#if defined(__ARM_FEATURE_DOTPROD)
|
|
|
+ sumi0 = vaddq_s32(sumi0, sumi1);
|
|
|
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
|
|
|
+
|
|
|
+ sumf += d * (float) vaddvq_s32(sumi0);
|
|
|
+#else
|
|
|
+ sumi0 = vaddq_s16(sumi0, sumi1);
|
|
|
+ sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
|
|
|
+
|
|
|
+ sumf += d * (float) vaddlvq_s16(sumi0);
|
|
|
+#endif
|
|
|
+ }
|
|
|
+
|
|
|
+ *s = sumf;
|
|
|
+
|
|
|
+#elif defined(__AVX2__)
|
|
|
+ __m256 sumf = _mm256_setzero_ps();
|
|
|
+
|
|
|
+ for (int i = 0; i < nb; ++i) {
|
|
|
+ // 16-bit sums, because 256*127 still fits
|
|
|
+ __m256i sumi0 = _mm256_setzero_si256();
|
|
|
+ __m256i sumi1 = _mm256_setzero_si256();
|
|
|
+
|
|
|
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
|
|
|
+ __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
|
|
|
+ __m256i qx1 = _mm256_srli_epi16(qx0, 2);
|
|
|
+ __m256i qx2 = _mm256_srli_epi16(qx0, 4);
|
|
|
+ __m256i qx3 = _mm256_srli_epi16(qx0, 6);
|
|
|
+
|
|
|
+ // 0, 1, 2 (should not be 3)
|
|
|
+ qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
|
|
|
+ qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
|
|
|
+ qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
|
|
|
+ qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
|
|
|
+
|
|
|
+ const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0));
|
|
|
+ const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
|
|
|
+ const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
|
|
|
+ const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
|
|
|
+
|
|
|
+ qx0 = _mm256_maddubs_epi16(qx0, qy0);
|
|
|
+ qx1 = _mm256_maddubs_epi16(qx1, qy1);
|
|
|
+ qx2 = _mm256_maddubs_epi16(qx2, qy2);
|
|
|
+ qx3 = _mm256_maddubs_epi16(qx3, qy3);
|
|
|
+
|
|
|
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
|
|
|
+ sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
|
|
|
+ }
|
|
|
+
|
|
|
+ const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
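+ // Same correction as in the tq1_0 paths: x was treated as unsigned {0, 1, 2}, so the
+ // per-block sums of y (bsums) are subtracted to recenter the products to {-1, 0, +1}.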
|
|
|
+ const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
|
|
|
+
|
|
|
+ sumi0 = _mm256_add_epi16(sumi0, sumi1);
|
|
|
+ sumi0 = _mm256_sub_epi16(sumi0, ysum);
|
|
|
+ sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
|
|
|
+
|
|
|
+ sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
|
|
|
+ }
|
|
|
+
|
|
|
+ *s = hsum_float_8(sumf);
|
|
|
+
|
|
|
+#else
|
|
|
+ float sumf = 0.0f;
|
|
|
+
|
|
|
+ for (int i = 0; i < nb; ++i) {
|
|
|
+ int32_t sumi = 0;
|
|
|
+
|
|
|
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
|
|
|
+ for (size_t l = 0; l < 4; ++l) {
|
|
|
+ for (size_t k = 0; k < 32; ++k) {
|
|
|
+ sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
|
|
+
|
|
|
+ sumf += (float) sumi * d;
|
|
|
+ }
|
|
|
+
|
|
|
+ *s = sumf;
|
|
|
+#endif
|
|
|
+}
|
|
|
+
|
|
|
void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
|
|
assert(nrc == 1);
|
|
|
UNUSED(nrc);
|
|
|
@@ -14800,6 +15480,14 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
|
|
}
|
|
|
}
|
|
|
} break;
|
|
|
+ case GGML_TYPE_TQ1_0:
|
|
|
+ {
|
|
|
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb);
|
|
|
+ } break;
|
|
|
+ case GGML_TYPE_TQ2_0:
|
|
|
+ {
|
|
|
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);
|
|
|
+ } break;
|
|
|
case GGML_TYPE_IQ1_S:
|
|
|
{
|
|
|
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
|