Просмотр исходного кода

kleidiai: add optimized per-channel kernels for Q8_0 (#16993)

Charles Xu 2 месяцев назад
Родитель
Сommit
8c583242ad

+ 15 - 3
ggml/src/ggml-cpu/CMakeLists.txt

@@ -590,6 +590,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             ${KLEIDIAI_SRC}/kai/ukernels/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
 
@@ -608,23 +609,34 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
-            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c)
 
         if (NOT DOTPROD_ENABLED MATCHES -1)
             list(APPEND GGML_KLEIDIAI_SOURCES
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
-                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c)
         endif()
 
         if (NOT I8MM_ENABLED MATCHES -1)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c)
         endif()
 
         if (NOT SME_ENABLED MATCHES -1)
             list(APPEND GGML_KLEIDIAI_SOURCES
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c

+ 283 - 0
ggml/src/ggml-cpu/kleidiai/kernels.cpp

@@ -4,6 +4,7 @@
 
 // KleidiAI micro-kernels
 #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
+#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
@@ -11,20 +12,31 @@
 #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
 #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
+#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
 
 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
 #include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
+#include "kai_lhs_quant_pack_qai8dxp_f32.h"
 
 #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
 #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
 #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
+#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
 
 #include "kai_common.h"
 
 #include "simd-mappings.h"
 
+#define GGML_COMMON_DECL_CPP
+#include "ggml-common.h"
+
 #include "kernels.h"
 
 #define NELEMS(x) sizeof(x) / sizeof(*x)
@@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
     Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
 }
 
+template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
+static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
+                                     const void* lhs, const void* rhs, void* dst,
+                                     size_t dst_stride_row, size_t dst_stride_col,
+                                     float clamp_min, float clamp_max) {
+    Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
+}
+
 template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
 static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
     return Fn(m, k, bl, mr, kr, sr);
@@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
     Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
 }
 
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
+static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
+                                            size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
+    Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
+}
+
 template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
 static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
     return Fn(n, k, nr, kr, bl);
@@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
        static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
 }
 
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
+static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
+                                               size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
+                                               void* rhs_packed, size_t extra_bytes, const void* params) {
+    Fn(num_groups, n, k, nr, kr, sr,
+       static_cast<const int8_t*>(rhs),
+       static_cast<const float*>(bias),
+       static_cast<const float*>(scale),
+       rhs_packed, extra_bytes,
+       static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
+}
+
 template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
 static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
                                                size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
@@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
     GGML_UNUSED(kr);
 }
 
+static void dequantize_row_qsi8cxp(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t k,
+    float *out,
+    size_t nr,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    GGML_UNUSED(bl);
+    GGML_UNUSED(num_bytes_multiplier);
+
+    const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
+    const size_t group_idx = row_idx / nr;
+    const size_t row_in_group = row_idx % nr;
+
+    const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
+    const int8_t  * data_base = reinterpret_cast<const int8_t *>(group_ptr);
+
+    const size_t num_blocks = k_internal / kr;
+
+    for (size_t block = 0; block < num_blocks; ++block) {
+        const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
+        for (size_t i = 0; i < kr; ++i) {
+            const size_t k_idx = block * kr + i;
+            if (k_idx < (size_t) k) {
+                out[k_idx] = static_cast<float>(block_ptr[i]);
+            }
+        }
+    }
+
+    const uint8_t * sums_ptr = group_ptr + nr * k_internal;
+    GGML_UNUSED(sums_ptr);
+
+    const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
+    const float scale = scale_ptr[row_in_group];
+
+    if (scale == 0.0f) {
+        for (size_t i = 0; i < (size_t) k; ++i) {
+            out[i] = 0.0f;
+        }
+        return;
+    }
+
+    for (size_t i = 0; i < (size_t) k; ++i) {
+        out[i] *= scale;
+    }
+}
+
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
     {
@@ -548,6 +637,174 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #endif
 };
 
+static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
+#if defined(__ARM_FEATURE_SME)
+    {
+        /* SME GEMM */
+        {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* SME GEMV */
+        {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+            /* .to_float              = */ dequantize_row_qsi8cxp,
+            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .pack_func_ex          = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+        },
+        /* .required_cpu       = */ CPU_FEATURE_SME,
+        /* .lhs_type           = */ GGML_TYPE_F32,
+        /* .rhs_type           = */ GGML_TYPE_Q8_0,
+        /* .op_type            = */ GGML_TYPE_F32,
+    },
+#endif
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    {
+        /* I8MM GEMM */
+        {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* I8MM GEMV (dotprod fallback) */
+        {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+            /* .to_float              = */ dequantize_row_qsi8cxp,
+            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .pack_func_ex          = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+        },
+        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .lhs_type           = */ GGML_TYPE_F32,
+        /* .rhs_type           = */ GGML_TYPE_Q8_0,
+        /* .op_type            = */ GGML_TYPE_F32,
+    },
+#endif
+#if defined(__ARM_FEATURE_DOTPROD)
+    {
+        /* DOTPROD GEMM */
+        {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* DOTPROD GEMV */
+        {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+            /* .to_float              = */ dequantize_row_qsi8cxp,
+            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .pack_func_ex          = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+        },
+        /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
+        /* .lhs_type           = */ GGML_TYPE_F32,
+        /* .rhs_type           = */ GGML_TYPE_Q8_0,
+        /* .op_type            = */ GGML_TYPE_F32,
+    },
+#endif
+};
+
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
     ggml_kleidiai_kernels * kernel = nullptr;
 
@@ -562,6 +819,17 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
                 break;
             }
         }
+        if (!kernel) {
+            for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
+                if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
+                    gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
+                    gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
+                    gemm_gemv_kernels_q8[i].op_type  == tensor->type) {
+                    kernel = &gemm_gemv_kernels_q8[i];
+                    break;
+                }
+            }
+        }
 #endif
     }
 
@@ -582,3 +850,18 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
 
     return kernels;
 }
+
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
+    ggml_kleidiai_kernels * kernels = nullptr;
+
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
+    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
+        if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
+            kernels = &gemm_gemv_kernels_q8[i];
+            break;
+        }
+    }
+#endif
+
+    return kernels;
+}

+ 1 - 0
ggml/src/ggml-cpu/kleidiai/kernels.h

@@ -87,3 +87,4 @@ struct ggml_kleidiai_kernels {
 
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);

+ 235 - 34
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

@@ -5,10 +5,13 @@
 #include <assert.h>
 #include <atomic>
 #include <cfloat>
+#include <cmath>
+#include <algorithm>
 #include <stdexcept>
 #include <stdint.h>
 #include <string.h>
 #include <string>
+#include <vector>
 #if defined(__linux__)
 #include <asm/hwcap.h>
 #include <sys/auxv.h>
@@ -38,8 +41,9 @@
 
 struct ggml_kleidiai_context {
     cpu_feature features;
-    ggml_kleidiai_kernels * kernels;
-} static ctx = { CPU_FEATURE_NONE, NULL };
+    ggml_kleidiai_kernels * kernels_q4;
+    ggml_kleidiai_kernels * kernels_q8;
+} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
 
 static const char* cpu_feature_to_string(cpu_feature f) {
     switch (f) {
@@ -73,10 +77,14 @@ static void init_kleidiai_context(void) {
         if (sme_enabled != 0) {
             ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
-        ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+        ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+        ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
 #ifndef NDEBUG
-        if (ctx.kernels) {
-            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+        if (ctx.kernels_q4) {
+            GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+        }
+        if (ctx.kernels_q8) {
+            GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
         }
 #endif
     }
@@ -130,6 +138,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (kernels->rhs_type == GGML_TYPE_Q4_0) {
             if (!lhs_info->packed_size_ex) return false;
             size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
+        } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
+            if (!lhs_info->packed_size_ex) return false;
+            size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
         } else if (kernels->rhs_type == GGML_TYPE_F16) {
             if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
             const int64_t lhs_batch_size0 = op->src[1]->ne[2];
@@ -149,11 +160,13 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (dst->op == GGML_OP_MUL_MAT) {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                 return compute_forward_q4_0(params, dst);
+            } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
+                return compute_forward_q8_0(params, dst);
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_fp16(params, dst);
             }
         } else if (dst->op == GGML_OP_GET_ROWS) {
-            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+            if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
                 return compute_forward_get_rows(params, dst);
             }
         }
@@ -400,19 +413,120 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }
 
-    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
-        if (!ctx.kernels) {
+    bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
+
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        if (!kernels) {
             return false;
         }
 
+        bool is_gemv = src1->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+
+        if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
+            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
+            return false;
+        }
+
+        const int ith = params->ith;
+        const int nth_raw = params->nth;
+        const int nth = nth_raw > 0 ? nth_raw : 1;
+
+        const size_t k = ne00;
+        const size_t m = ne11;
+        const size_t n = ne01;
+
+        size_t mr = kernel->get_mr();
+        size_t kr = kernel->get_kr();
+        size_t sr = kernel->get_sr();
+
+        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
+        uint8_t * lhs_packed       = static_cast<uint8_t *>(params->wdata);
+        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+        const size_t n_step = kernel->get_n_step();
+        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+        const size_t n_start = ith * num_n_per_thread;
+
+        size_t n_to_process = 0;
+        if (n_start < n) {
+            n_to_process = num_n_per_thread;
+            if ((n_start + n_to_process) > n) {
+                n_to_process = n - n_start;
+            }
+        }
+
+        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
+        const size_t m_start = ith * num_m_per_thread;
+        size_t m_to_process = num_m_per_thread;
+        if ((m_start + m_to_process) > m) {
+            m_to_process = m - m_start;
+        }
+
+        if (m_start < m) {
+            const size_t src_stride        = src1->nb[1];
+            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
+            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
+            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+            lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        const size_t dst_stride        = dst->nb[1];
+        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
+        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
+        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
+        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+        const void * lhs_ptr           = static_cast<const void *>(lhs_packed + lhs_packed_offset);
+        float * dst_ptr                = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+        if (n_to_process > 0) {
+            kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+                                  sizeof(float), -FLT_MAX, FLT_MAX);
+        }
+
+        return true;
+    }
+
+    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
-        kernel_info * kernel        = &ctx.kernels->gemm;
+        ggml_kleidiai_kernels * kernels = nullptr;
+        size_t block_len = 0;
+        size_t num_bytes_multiplier = 0;
+
+        if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+            if (!ctx.kernels_q4) {
+                return false;
+            }
+            kernels = ctx.kernels_q4;
+            block_len = QK4_0;
+            num_bytes_multiplier = sizeof(uint16_t);
+        } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
+            if (!ctx.kernels_q8) {
+                return false;
+            }
+            kernels = ctx.kernels_q8;
+            block_len = QK8_0;
+            num_bytes_multiplier = sizeof(float);
+        } else {
+            return false;
+        }
+
+        rhs_packing_info * rhs_info = &kernels->rhs_info;
+        kernel_info * kernel        = &kernels->gemm;
         if (!rhs_info->to_float || !kernel->get_nr) {
             return false;
         }
@@ -423,8 +537,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const size_t block_rows = kernel->get_nr();
         const size_t kr         = kernel->get_kr();
 
-        const size_t num_bytes_multiplier = sizeof(uint16_t);
-        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);
 
         const int ith = params->ith;
         const int nth = params->nth;
@@ -439,7 +552,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
 
             float *out = (float *)((char *)dst->data + i * nb1);
-            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
         }
 
         return true;
@@ -447,21 +560,91 @@ class tensor_traits : public ggml::cpu::tensor_traits {
 
 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
-        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
-        GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
-        size_t nr      = ctx.kernels->gemm.get_nr();
-        size_t kr      = ctx.kernels->gemm.get_kr();
-        size_t sr      = ctx.kernels->gemm.get_sr();
 
-        struct kai_rhs_pack_qs4cxs1s0_param params;
-        params.lhs_zero_point = 1;
-        params.rhs_zero_point = 8;
-        ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params);
+        if (tensor->type == GGML_TYPE_Q4_0) {
+            if (!ctx.kernels_q4) {
+                return -1;
+            }
+            size_t nr = ctx.kernels_q4->gemm.get_nr();
+            size_t kr = ctx.kernels_q4->gemm.get_kr();
+            size_t sr = ctx.kernels_q4->gemm.get_sr();
+
+            struct kai_rhs_pack_qs4cxs1s0_param params;
+            params.lhs_zero_point = 1;
+            params.rhs_zero_point = 8;
+            ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
+                                                  static_cast<const uint8_t *>(data),
+                                                  nullptr, nullptr, tensor->data, 0, &params);
+            GGML_UNUSED(data_size);
+            return 0;
+        } else if (tensor->type == GGML_TYPE_Q8_0) {
+            if (!ctx.kernels_q8) {
+                return -1;
+            }
+
+            const size_t row_stride = tensor->nb[1];
+            const size_t k_blocks   = (k + QK8_0 - 1) / QK8_0;
+
+            std::vector<int8_t> qdata(n * k, 0);
+            std::vector<float> scales(n, 0.0f);
+
+            for (size_t row = 0; row < n; ++row) {
+                const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
+                    static_cast<const uint8_t *>(data) + row * row_stride);
+
+                float max_abs = 0.0f;
+                for (size_t block = 0; block < k_blocks; ++block) {
+                    const block_q8_0 & blk = row_blocks[block];
+                    const float d = GGML_FP16_TO_FP32(blk.d);
+                    for (size_t l = 0; l < QK8_0; ++l) {
+                        const size_t linear_idx = block * QK8_0 + l;
+                        if (linear_idx >= k) {
+                            break;
+                        }
+                        const float value = d * blk.qs[l];
+                        max_abs = std::max(max_abs, std::fabs(value));
+                    }
+                }
+
+                float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
+                scales[row] = scale;
+                const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
+
+                for (size_t block = 0; block < k_blocks; ++block) {
+                    const block_q8_0 & blk = row_blocks[block];
+                    const float d = GGML_FP16_TO_FP32(blk.d);
+                    for (size_t l = 0; l < QK8_0; ++l) {
+                        const size_t linear_idx = block * QK8_0 + l;
+                        if (linear_idx >= k) {
+                            break;
+                        }
+                        const float value = d * blk.qs[l];
+                        int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
+                        q = std::clamp(q, -127, 127);
+                        qdata[row * k + linear_idx] = static_cast<int8_t>(q);
+                    }
+                }
+            }
+
+            size_t nr = ctx.kernels_q8->gemm.get_nr();
+            size_t kr = ctx.kernels_q8->gemm.get_kr();
+            size_t sr = ctx.kernels_q8->gemm.get_sr();
+
+            struct kai_rhs_pack_qsi8cx_params params;
+            params.lhs_zero_point = 1;
+            params.scale_multiplier = 1.0f;
+
+            ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
+                                                  qdata.data(), nullptr, scales.data(),
+                                                  tensor->data, 0, &params);
+            GGML_UNUSED(data_size);
+            return 0;
+        }
 
-        return 0;
         GGML_UNUSED(data_size);
+        return -1;
     }
 };
 
@@ -518,27 +701,45 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
 }
 
 static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
-    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
-    GGML_ASSERT(ctx.kernels);
+    GGML_UNUSED(buft);
 
-    const size_t n  = tensor->ne[1];
-    const size_t k  = tensor->ne[0];
-    const size_t nr = ctx.kernels->gemm.get_nr();
-    const size_t kr = ctx.kernels->gemm.get_kr();
+    const size_t n = tensor->ne[1];
+    const size_t k = tensor->ne[0];
+
+    ggml_kleidiai_kernels * kernels = nullptr;
+    size_t block_len = 0;
+
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        GGML_ASSERT(ctx.kernels_q4);
+        kernels = ctx.kernels_q4;
+        block_len = QK4_0;
+    } else if (tensor->type == GGML_TYPE_Q8_0) {
+        GGML_ASSERT(ctx.kernels_q8);
+        kernels = ctx.kernels_q8;
+        block_len = QK8_0;
+    } else {
+        return 0;
+    }
 
-    return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
+    const size_t nr = kernels->gemm.get_nr();
+    const size_t kr = kernels->gemm.get_kr();
+    const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
+    const size_t raw     = ggml_nbytes(tensor);
 
-    GGML_UNUSED(buft);
+    return packed > raw ? packed : raw;
 }
 
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
         if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
-            op->src[0]->type == GGML_TYPE_Q4_0 &&
+            (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
-            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
+            if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }