Sfoglia il codice sorgente

ggml : add ARM_NEON ggml_vec_dot_q4_1()

Georgi Gerganov 2 anni fa
parent
commit
3b44d30d9b
1 ha cambiato i file con 39 aggiunte e 0 eliminazioni
  1. 39 0
      ggml.c

+ 39 - 0
ggml.c

@@ -2008,6 +2008,45 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
     sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
+#elif defined(__ARM_NEON)
+    float sum00 = 0.0f;
+    float sum01 = 0.0f;
+    float sum10 = 0.0f;
+    float sum11 = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict y0 = &y[i + 0];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+
+        // and with 0xf
+        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
+        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
+
+        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
+        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+
+        // dot product into uint16x8_t
+        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
+        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
+
+        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
+        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
+
+        const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
+        const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+
+        sum00 += x0->m*y0->m;
+        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
+        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+        sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
+    }
+
+    sumf = QK*sum00 + sum01 + sum10 + sum11;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {