|
|
@@ -2008,6 +2008,45 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
|
|
|
|
|
sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
|
|
|
+#elif defined(__ARM_NEON)
|
|
|
+ float sum00 = 0.0f;
|
|
|
+ float sum01 = 0.0f;
|
|
|
+ float sum10 = 0.0f;
|
|
|
+ float sum11 = 0.0f;
|
|
|
+
|
|
|
+ for (int i = 0; i < nb; ++i) {
|
|
|
+ const block_q4_1 * restrict x0 = &x[i + 0];
|
|
|
+ const block_q4_1 * restrict y0 = &y[i + 0];
|
|
|
+
|
|
|
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
|
|
|
+
|
|
|
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
|
|
+ const uint8x16_t v1_0 = vld1q_u8(y0->qs);
|
|
|
+
|
|
|
+ // and with 0xf
|
|
|
+ const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
|
|
|
+ const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
|
|
|
+
|
|
|
+ const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
|
|
|
+ const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
|
|
|
+
|
|
|
+ // dot product into uint16x8_t
|
|
|
+ const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
|
|
|
+ const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
|
|
|
+
|
|
|
+ const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
|
|
|
+ const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
|
|
|
+
|
|
|
+ const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
|
|
|
+ const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
|
|
|
+
|
|
|
+ sum00 += x0->m*y0->m;
|
|
|
+ sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
|
|
|
+ sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
|
|
|
+ sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
|
|
|
+ }
|
|
|
+
|
|
|
+ sumf = QK*sum00 + sum01 + sum10 + sum11;
|
|
|
#else
|
|
|
// scalar
|
|
|
for (int i = 0; i < nb; i++) {
|