Browse Source

ggml-cpu: arm64: q4_K repack gemm and gemv implementations (i8mm) (#16739)

* Enabled q4_K_8x8_q8_K path on ARM

* wip: I8mm qs multiplication, pending bias

* cpu : arm : REPACK gemm q4_K8x8 implementation

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Guard gemm with proper features, improved superblock scale and min calc

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* cpu: arm: Implemented REPACK gemv for Q4_K

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Removed completed TODO

* Fixed missing guards when selecting optimal repack type for Q4_K

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Fixed macro guard for gemv

* Fixed wrong comment in GEMV

* Fixed warning for unused variable

* vdotq_s32 -> ggml_vdotq_s32

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Clang-format issues

* Apply suggestions from code review

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* Removed unnecessary GGML_UNUSED

* Fixed guards in q4_k gemm and gemv (repack)

---------

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>
Co-authored-by: Diego Devesa <slarengh@gmail.com>
Alberto Cabrera Pérez 1 month ago
parent
commit
dbb852b549

+ 0 - 2
ggml/src/ggml-cpu/arch-fallback.h

@@ -51,10 +51,8 @@
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 // repack.cpp
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
-#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
-#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)

+ 388 - 0
ggml/src/ggml-cpu/arch/arm/repack.cpp

@@ -24,6 +24,29 @@
 
 #define UNUSED GGML_UNUSED
 
+static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
+                                             int16x8_t *     out_mins,
+                                             int8_t *        out_scales) {
+    constexpr uint32_t kmask1 = 0x3f3f3f3f;
+    constexpr uint32_t kmask2 = 0x0f0f0f0f;
+    constexpr uint32_t kmask3 = 0x03030303;
+    constexpr uint8_t  scales_size = 12;
+
+    uint32_t sm[3];
+    memcpy(sm, scales_in, scales_size);
+
+    const uint32_t   mins_0_3 = sm[1] & kmask1;
+    const uint32_t   mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
+
+    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
+
+    uint32_t scales_u32[2];
+    scales_u32[0] = sm[0] & kmask1;
+    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+    memcpy(out_scales, scales_u32, 8);
+}
+
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
@@ -474,6 +497,162 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q4_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON)
+    constexpr int    col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[ncols_interleaved / 4];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < ncols_interleaved / 4; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q4_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q4_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
+            float32x4_t q4_dmin_0  = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q4_dmin_1  = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0   = vmulq_f32(q4_dmin_0, q8_d);
+            float32x4_t sb_min_1   = vmulq_f32(q4_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            // 2 sb each iteration
+            int32x4_t acc_lo[col_pairs];
+            int32x4_t acc_hi[col_pairs];
+
+            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t         bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_pairs; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q4sb_mins[2];  // int16 as its needed for bias_acc later
+                int16x8_t q4sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t    aux_q4sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                }
+
+                const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
+
+                // Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
+                // but still need the qs to use the low and hi bits from q4
+                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+                int8x16_t      q8_qs[8];
+                for (int i = 0; i < 8; i++) {
+                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
+                }
+
+                // Q4s columns iterated in pairs (01, 23, 45, 67)
+                for (int cp = 0; cp < col_pairs; cp++) {
+                    uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
+                    uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
+                    uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
+                    uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
+
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]);  // 0 .. 7
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]);  // 8 ..15
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]);  // 16..23
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]);  // 24..31
+
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]);  // 32..39
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]);  // 40..47
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]);  // 48..55
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]);  // 56..63
+                }
+
+                // Iterates over a pair of column pairs (4 columns) to use a single 128 register
+                // p = 0 -> 0123  p2 -> 4567
+                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+                    int16x4_t   group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
+                    int16x4_t   group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
+                    float32x4_t sb_scale        = p == 0 ? sb_scale_0 : sb_scale_1;
+
+                    // 0123 or 4567
+                    // TODO: Single superblock mul at the end of the superblock
+                    float32x4_t sumf_0 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+                    float32x4_t sumf_1 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+                }
+
+                // Multiply Acc bsum + mins
+                // Each pair of subblocks share the same bsums
+                // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
+                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                // cols 0-3 bias
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+                // cols 4-7 bias
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+            }  // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -1889,3 +2068,212 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
     ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
+
+void ggml_gemm_q4_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    constexpr int    q8_k_blocklen = 4;
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+
+    // 8 accumulators: 2 row pairs × 4 col pairs
+    float32x4_t acc_f32[blocklen];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < blocklen; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // bsums pairs belongs to the same q8_k subblock
+                const int16x8_t bsums[4]{
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                int16_t bsums_arr[4][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
+                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
+                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+                for (int i = 0; i < 8; i++) {
+                    acc[i]      = vdupq_n_s32(0);
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    int8_t    q4sb_scales[2][8];
+                    int16x8_t q4sb_mins[2];  // int16 as its needed for bias_acc later
+                    for (int i = 0; i < 2; i++) {
+                        const int offset = sb * 24 + i * 12;
+                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
+                    }
+
+                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
+                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+
+                    int8x16_t q8_qs_01[8];
+                    int8x16_t q8_qs_23[8];
+
+                    // Load 32-byte per row pair, 1 subblock each time
+                    for (int i = 0; i < 8; i++) {
+                        const int offset = i * 32;  // 16 for row 01, 16 for row 23
+                        q8_qs_01[i]      = vld1q_s8(q8_base + offset);
+                        q8_qs_23[i]      = vld1q_s8(q8_base + offset + 16);
+                    }
+
+                    const int8x16_t q8s[2][8] = {
+                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
+                          q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
+                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
+                          q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
+                    };
+
+                    // Q4s columns iterated in pairs (01, 23, 45, 67)
+                    for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+                        for (int i = 0; i < 4; i++) {
+                            sb_acc[i] = vdupq_n_s32(0);
+                        }
+
+                        uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39
+                        uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47
+                        uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55
+                        uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63
+                        const int8x16_t q4_nibbles[2][4] = {
+                            {
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
+                            },
+                            {
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
+                            }
+                        };
+
+                        // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
+                        // for each of the internal 32 qs subblock (blk)
+                        for (int rp = 0; rp < 2; rp++) {
+                            for (int blk = 0; blk < 2; blk++) {
+                                const int8x16_t * q8  = &q8s[rp][4 * blk];
+                                const int8x16_t * q4  = q4_nibbles[blk];
+                                int32x4_t         acc = sb_acc[2 * rp + blk];
+                                // mul add for each qs in the same subblock
+                                for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
+                                    acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
+                                }
+                                sb_acc[2 * rp + blk] = acc;
+                            }
+                        }
+
+                        // Scales[i] corresponds to column i
+                        const int scale_offset = cp * 2;
+                        for (int blk = 0; blk < 2; blk++) {
+                            const int32x4_t block_scale = {
+                                (int32_t) q4sb_scales[blk][scale_offset],
+                                (int32_t) q4sb_scales[blk][scale_offset],
+                                (int32_t) q4sb_scales[blk][scale_offset + 1],
+                                (int32_t) q4sb_scales[blk][scale_offset + 1],
+                            };
+                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
+                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
+                        }
+                    }
+
+                    // Multiply Acc bsum + mins
+                    for (int q8_row = 0; q8_row < 4; q8_row++) {
+                        // Each pair of subblocks share the same bsums
+                        // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
+                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
+
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+                    }
+                }  // for sb
+
+                // Reorder of i8mm output with bias and output layout
+                for (int i = 0; i < 8; i++) {
+                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);
+                }
+                int32x4_t reorder_acc[8] = {
+                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+                };
+
+                for (int i = 0; i < q8_k_blocklen; i++) {
+                    for (int j = 0; j < 2; j++) {
+                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t       q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
+                        const float32x4_t dmins   = vmulq_f32(q4_dmin, q8_d);
+
+                        float32x4_t       q4_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q4_d, q8_d);
+
+                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
+                        acc_f32[2 * i + j] =
+                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+                    }
+                }
+            }  // for b
+
+            // With the previous reorder, the tile is already in the correct memory layout.
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}

+ 5 - 0
ggml/src/ggml-cpu/repack.cpp

@@ -1961,6 +1961,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q4_K_8x8_q8_K;
             }
         }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q4_K_8x8_q8_K;
+            }
+        }
     } else if (cur->type == GGML_TYPE_Q2_K) {
         if (ggml_cpu_has_avx512()) {
             if (cur->ne[1] % 8 == 0) {