Ver código fonte

ggml-cpu: ARM64: repack version of q8_0 (dotprod and i8mm) (#18096)

* wip: skeleton for q8_0 repack

* q8_0 repack GEMV implementations

* GEMM implementations

* Formatting

* Fixed format consistency of repack gemm and gemv declarations

* gemv and gemm generic location consistent with declarations

* Removed non-correct unused variables statements

* Cleanup, consistent style

* Missing generic fallbacks for x86 and powerpc
Alberto Cabrera Pérez 1 mês atrás
pai
commit
669696e00d

+ 28 - 0
ggml/src/ggml-cpu/arch-fallback.h

@@ -43,6 +43,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -51,6 +53,8 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 // repack.cpp
 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
@@ -67,10 +71,14 @@
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__POWERPC__) || defined(__powerpc__)
 // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
 // quants.c
@@ -91,6 +99,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -99,6 +109,8 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__loongarch64)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -119,6 +131,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -127,6 +141,8 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__riscv)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -154,6 +170,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
@@ -161,6 +179,8 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__s390x__)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -187,6 +207,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -195,6 +217,8 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__wasm__)
 // quants.c
 #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
@@ -223,6 +247,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -231,4 +257,6 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #endif

+ 283 - 0
ggml/src/ggml-cpu/arch/arm/repack.cpp

@@ -786,6 +786,133 @@ void ggml_gemv_q4_K_8x8_q8_K(int                        n,
     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q8_0_4x4_q8_0(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t        acc   = vdupq_n_f32(0);
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low  = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd     = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x16x2_t a  = vld1q_s8_x2(a_ptr->qs);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret = vdupq_n_s32(0);
+
+            ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
+            ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
+            ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
+            ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
+
+            ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
+            ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
+            ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
+            ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q8_0_4x8_q8_0(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t        acc   = vdupq_n_f32(0);
+
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low  = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd     = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x8x4_t  a_chunks = vld1_s8_x4(a_ptr->qs);
+            int8x16_t   a0       = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
+            int8x16_t   a1       = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
+            int8x16_t   a2       = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
+            int8x16_t   a3       = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
+            float16x4_t ad       = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret0 = vdupq_n_s32(0);
+            int32x4_t ret1 = vdupq_n_s32(0);
+
+            // 0..7
+            ret0 = vdotq_s32(ret0, b_low.val[0], a0);
+            ret1 = vdotq_s32(ret1, b_low.val[1], a0);
+            // 8..15
+            ret0 = vdotq_s32(ret0, b_low.val[2], a1);
+            ret1 = vdotq_s32(ret1, b_low.val[3], a1);
+            // 16..23
+            ret0 = vdotq_s32(ret0, b_high.val[0], a2);
+            ret1 = vdotq_s32(ret1, b_high.val[1], a2);
+            // 24..31
+            ret0 = vdotq_s32(ret0, b_high.val[2], a3);
+            ret1 = vdotq_s32(ret1, b_high.val[3], a3);
+
+            int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -2610,3 +2737,159 @@ void ggml_gemm_q4_K_8x8_q8_K(int                        n,
 #endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
     ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
+
+
+void ggml_gemm_q8_0_4x4_q8_0(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+            float32x4_t sumf[4];
+            for (int m = 0; m < 4; m++) {
+                sumf[m] = vdupq_n_f32(0);
+            }
+
+            for (int l = 0; l < nb; l++) {
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
+
+                int32x4_t sumi_0 = vdupq_n_s32(0);
+                int32x4_t sumi_1 = vdupq_n_s32(0);
+                int32x4_t sumi_2 = vdupq_n_s32(0);
+                int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                for (int k_group = 0; k_group < 8; k_group += 4) {
+                    int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
+                    int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);
+
+                    for (int k = 0; k < 4; k++) {
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
+                    }
+                }
+
+                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+            }
+
+            for (int m = 0; m < 4; m++) {
+                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+            }
+        }
+    }
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q8_0_4x8_q8_0(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
+
+    for (int y = 0; y < nr; y += 4) {
+        const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
+
+        for (int x = 0; x < nc; x += ncols_interleaved) {
+            const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
+            const block_q8_0x4 * a_ptr = a_ptr_base;
+
+            float32x4_t acc_f32[4];
+            for (int i = 0; i < 4; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                int32x4_t acc[4];
+                for (int i = 0; i < 4; i++) {
+                    acc[i] = vdupq_n_s32(0);
+                }
+
+                // Process 4 chunks of 8 positions each
+                for (int chunk = 0; chunk < 4; chunk++) {
+                    int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
+                    int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
+                    int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
+                    int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);
+
+                    acc[0] = vmmlaq_s32(acc[0], a01, b01);
+                    acc[1] = vmmlaq_s32(acc[1], a01, b23);
+                    acc[2] = vmmlaq_s32(acc[2], a23, b01);
+                    acc[3] = vmmlaq_s32(acc[3], a23, b23);
+                }
+
+                // Reorder outputs from 2×2 tiles to row-major
+                // acc[0] = [r0c0, r0c1, r1c0, r1c1]
+                // acc[1] = [r0c2, r0c3, r1c2, r1c3]
+                // acc[2] = [r2c0, r2c1, r3c0, r3c1]
+                // acc[3] = [r2c2, r2c3, r3c2, r3c3]
+                int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
+                int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
+                int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
+                int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));
+
+                // Scales
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));
+
+                acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
+                acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
+                acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
+                acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));
+
+                a_ptr++;
+                b_ptr++;
+            }
+
+            for (int row = 0; row < 4; row++) {
+                vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
+            }
+        }
+    }
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}

+ 286 - 0
ggml/src/ggml-cpu/repack.cpp

@@ -692,6 +692,100 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
+void ggml_gemv_q8_0_4x4_q8_0_generic(int                        n,
+                                     float * GGML_RESTRICT      s,
+                                     size_t                     bs,
+                                     const void * GGML_RESTRICT vx,
+                                     const void * GGML_RESTRICT vy,
+                                     int                        nr,
+                                     int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 4;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[4];
+    int   sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / blocklen); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
+void ggml_gemv_q8_0_4x8_q8_0_generic(int                        n,
+                                     float * GGML_RESTRICT      s,
+                                     size_t                     bs,
+                                     const void * GGML_RESTRICT vx,
+                                     const void * GGML_RESTRICT vy,
+                                     int                        nr,
+                                     int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 8;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[4];
+    int   sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / blocklen); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -1219,8 +1313,129 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
+void ggml_gemm_q8_0_4x4_q8_0_generic(int                        n,
+                                     float * GGML_RESTRICT      s,
+                                     size_t                     bs,
+                                     const void * GGML_RESTRICT vx,
+                                     const void * GGML_RESTRICT vy,
+                                     int                        nr,
+                                     int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][4];
+    int   sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / blocklen); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
+                            }
+                            sumf[m][j] +=
+                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
+void ggml_gemm_q8_0_4x8_q8_0_generic(int                        n,
+                                     float * GGML_RESTRICT      s,
+                                     size_t                     bs,
+                                     const void * GGML_RESTRICT vx,
+                                     const void * GGML_RESTRICT vy,
+                                     int                        nr,
+                                     int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][4];
+    int   sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / blocklen); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
+                            }
+                            sumf[m][j] +=
+                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
 } // extern "C"
 
+static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) {
+    block_q8_0x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK8_0 * 4 / blck_size_interleave;
+    for (int i = 0; i < end; ++i) {
+        int src_id     = i % 4;
+        int src_offset = (i / 4) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
+    }
+    return out;
+}
+
 static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
     block_q4_0x4 out;
 
@@ -1534,6 +1749,38 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_UNUSED(data_size);
 }
 
+static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor *       t,
+                                    int                        interleave_block,
+                                    const void * GGML_RESTRICT data,
+                                    size_t                     data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q8_0);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    constexpr int nrows_interleaved = 4;
+
+    block_q8_0x4 *     dst = (block_q8_0x4 *) t->data;
+    const block_q8_0 * src = (const block_q8_0 *) data;
+    block_q8_0         dst_tmp[4];
+    int                nrow    = ggml_nrows(t);
+    int                nblocks = t->ne[0] / QK8_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q8_0x4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+}
+
 static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
     block_iq4_nlx4 out;
 
@@ -1702,6 +1949,14 @@ template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void *
     return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q8_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size);
+}
+
 // gemv
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemv(int, float *, size_t, const void *, const void *, int, int);
@@ -1738,6 +1993,14 @@ template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
     ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 // gemm
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemm(int, float *, size_t, const void *, const void *, int, int);
@@ -1774,6 +2037,14 @@ template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
     ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 class tensor_traits_base : public ggml::cpu::tensor_traits {
   public:
     virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
@@ -2168,6 +2439,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;
 
+    // instance for Q8_0
+    static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
+    static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0;
+
     if (cur->type == GGML_TYPE_Q4_0) {
         if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)
             || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) {
@@ -2218,6 +2493,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &iq4_nl_4x4_q8_0;
             }
         }
+    } else if (cur->type == GGML_TYPE_Q8_0) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &q8_0_4x8_q8_0;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &q8_0_4x4_q8_0;
+            }
+        }
     }
 
     return nullptr;

+ 8 - 0
ggml/src/ggml-cpu/repack.h

@@ -98,6 +98,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
 // Native implementations
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
@@ -120,6 +124,10 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
 #if defined(__cplusplus)
 } // extern "C"