|
|
@@ -4,6 +4,7 @@
|
|
|
|
|
|
// KleidiAI micro-kernels
|
|
|
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
|
|
+#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
|
|
|
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
|
|
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
|
|
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
|
|
@@ -11,20 +12,31 @@
|
|
|
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
|
|
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
|
|
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
|
|
|
+#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
|
|
|
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
|
|
|
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
|
|
|
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
|
|
|
+#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
|
|
|
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
|
|
|
|
|
|
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
|
|
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
|
|
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
|
|
|
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
|
|
+#include "kai_lhs_quant_pack_qai8dxp_f32.h"
|
|
|
|
|
|
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
|
|
|
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
|
|
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
|
|
+#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
|
|
|
|
|
|
#include "kai_common.h"
|
|
|
|
|
|
#include "simd-mappings.h"
|
|
|
|
|
|
+#define GGML_COMMON_DECL_CPP
|
|
|
+#include "ggml-common.h"
|
|
|
+
|
|
|
#include "kernels.h"
|
|
|
|
|
|
#define NELEMS(x) sizeof(x) / sizeof(*x)
|
|
|
@@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
|
|
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
|
|
}
|
|
|
|
|
|
+template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
|
|
|
+static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
|
|
+ const void* lhs, const void* rhs, void* dst,
|
|
|
+ size_t dst_stride_row, size_t dst_stride_col,
|
|
|
+ float clamp_min, float clamp_max) {
|
|
|
+ Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
|
|
+}
|
|
|
+
|
|
|
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
|
|
|
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
|
|
|
return Fn(m, k, bl, mr, kr, sr);
|
|
|
@@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
|
|
|
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
|
|
|
}
|
|
|
|
|
|
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
|
|
|
+static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
|
|
|
+ size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
|
|
|
+ Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
|
|
|
+}
|
|
|
+
|
|
|
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
|
|
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
|
|
|
return Fn(n, k, nr, kr, bl);
|
|
|
@@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
|
|
|
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
|
|
|
}
|
|
|
|
|
|
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
|
|
|
+static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
|
|
+ size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
|
|
|
+ void* rhs_packed, size_t extra_bytes, const void* params) {
|
|
|
+ Fn(num_groups, n, k, nr, kr, sr,
|
|
|
+ static_cast<const int8_t*>(rhs),
|
|
|
+ static_cast<const float*>(bias),
|
|
|
+ static_cast<const float*>(scale),
|
|
|
+ rhs_packed, extra_bytes,
|
|
|
+ static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
|
|
|
+}
|
|
|
+
|
|
|
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
|
|
|
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
|
|
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
|
|
|
@@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
|
|
|
GGML_UNUSED(kr);
|
|
|
}
|
|
|
|
|
|
+static void dequantize_row_qsi8cxp(
|
|
|
+ const void *packed_data,
|
|
|
+ int32_t row_idx,
|
|
|
+ int64_t k,
|
|
|
+ float *out,
|
|
|
+ size_t nr,
|
|
|
+ size_t packed_row_stride,
|
|
|
+ size_t kr,
|
|
|
+ size_t bl,
|
|
|
+ size_t num_bytes_multiplier
|
|
|
+) {
|
|
|
+ GGML_UNUSED(bl);
|
|
|
+ GGML_UNUSED(num_bytes_multiplier);
|
|
|
+
|
|
|
+ const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
|
|
|
+ const size_t group_idx = row_idx / nr;
|
|
|
+ const size_t row_in_group = row_idx % nr;
|
|
|
+
|
|
|
+ const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
|
|
|
+ const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
|
|
|
+
|
|
|
+ const size_t num_blocks = k_internal / kr;
|
|
|
+
|
|
|
+ for (size_t block = 0; block < num_blocks; ++block) {
|
|
|
+ const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
|
|
|
+ for (size_t i = 0; i < kr; ++i) {
|
|
|
+ const size_t k_idx = block * kr + i;
|
|
|
+ if (k_idx < (size_t) k) {
|
|
|
+ out[k_idx] = static_cast<float>(block_ptr[i]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const uint8_t * sums_ptr = group_ptr + nr * k_internal;
|
|
|
+    // nr int32 row sums live at sums_ptr; the per-row f32 scales follow them (used below)
|
|
|
+
|
|
|
+ const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
|
|
|
+ const float scale = scale_ptr[row_in_group];
|
|
|
+
|
|
|
+ if (scale == 0.0f) {
|
|
|
+ for (size_t i = 0; i < (size_t) k; ++i) {
|
|
|
+ out[i] = 0.0f;
|
|
|
+ }
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (size_t i = 0; i < (size_t) k; ++i) {
|
|
|
+ out[i] *= scale;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
#if defined(__ARM_FEATURE_SME)
|
|
|
{
|
|
|
@@ -548,6 +637,174 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
#endif
|
|
|
};
|
|
|
|
|
|
+static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
|
|
+#if defined(__ARM_FEATURE_SME)
|
|
|
+ {
|
|
|
+ /* SME GEMM */
|
|
|
+ {
|
|
|
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
|
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
|
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
|
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
|
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
|
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
|
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
|
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
|
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
|
|
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
|
|
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
|
|
+ },
|
|
|
+ /* .gemm_lhs_info = */ {
|
|
|
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
|
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ },
|
|
|
+ /* SME GEMV */
|
|
|
+ {
|
|
|
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
|
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
|
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
|
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
|
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
|
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
|
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
|
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
|
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
|
|
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
|
|
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
|
|
+ },
|
|
|
+ /* .gemv_lhs_info = */ {
|
|
|
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
|
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ },
|
|
|
+ /* .rhs_info = */ {
|
|
|
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
|
|
+ /* .to_float = */ dequantize_row_qsi8cxp,
|
|
|
+ /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
|
+ /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
|
+ /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
|
+ },
|
|
|
+ /* .required_cpu = */ CPU_FEATURE_SME,
|
|
|
+ /* .lhs_type = */ GGML_TYPE_F32,
|
|
|
+ /* .rhs_type = */ GGML_TYPE_Q8_0,
|
|
|
+ /* .op_type = */ GGML_TYPE_F32,
|
|
|
+ },
|
|
|
+#endif
|
|
|
+#if defined(__ARM_FEATURE_MATMUL_INT8)
|
|
|
+ {
|
|
|
+ /* I8MM GEMM */
|
|
|
+ {
|
|
|
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
|
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
|
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
|
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
|
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
|
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
|
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
|
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
|
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
|
|
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
|
|
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
|
|
+ },
|
|
|
+ /* .gemm_lhs_info = */ {
|
|
|
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
|
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ },
|
|
|
+ /* I8MM GEMV (dotprod fallback) */
|
|
|
+ {
|
|
|
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
|
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
|
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
|
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
|
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
|
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
|
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
|
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
|
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
|
|
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
|
|
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
|
|
+ },
|
|
|
+ /* .gemv_lhs_info = */ {
|
|
|
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
|
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ },
|
|
|
+ /* .rhs_info = */ {
|
|
|
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
|
|
+ /* .to_float = */ dequantize_row_qsi8cxp,
|
|
|
+ /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
|
+ /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
|
+ /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
|
+ },
|
|
|
+ /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
|
|
+ /* .lhs_type = */ GGML_TYPE_F32,
|
|
|
+ /* .rhs_type = */ GGML_TYPE_Q8_0,
|
|
|
+ /* .op_type = */ GGML_TYPE_F32,
|
|
|
+ },
|
|
|
+#endif
|
|
|
+#if defined(__ARM_FEATURE_DOTPROD)
|
|
|
+ {
|
|
|
+ /* DOTPROD GEMM */
|
|
|
+ {
|
|
|
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
|
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
|
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
|
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
|
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
|
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
|
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
|
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
|
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
|
|
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
|
|
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
|
|
+ },
|
|
|
+ /* .gemm_lhs_info = */ {
|
|
|
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
|
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ },
|
|
|
+ /* DOTPROD GEMV */
|
|
|
+ {
|
|
|
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
|
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
|
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
|
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
|
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
|
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
|
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
|
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
|
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
|
|
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
|
|
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
|
|
+ },
|
|
|
+ /* .gemv_lhs_info = */ {
|
|
|
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
|
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
|
+ },
|
|
|
+ /* .rhs_info = */ {
|
|
|
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
|
|
+ /* .to_float = */ dequantize_row_qsi8cxp,
|
|
|
+ /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
|
+ /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
|
+ /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
|
+ },
|
|
|
+ /* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
|
|
+ /* .lhs_type = */ GGML_TYPE_F32,
|
|
|
+ /* .rhs_type = */ GGML_TYPE_Q8_0,
|
|
|
+ /* .op_type = */ GGML_TYPE_F32,
|
|
|
+ },
|
|
|
+#endif
|
|
|
+};
|
|
|
+
|
|
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
|
|
ggml_kleidiai_kernels * kernel = nullptr;
|
|
|
|
|
|
@@ -562,6 +819,17 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
+ if (!kernel) {
|
|
|
+ for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
|
|
|
+ if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
|
|
|
+ gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
|
|
|
+ gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
|
|
|
+ gemm_gemv_kernels_q8[i].op_type == tensor->type) {
|
|
|
+ kernel = &gemm_gemv_kernels_q8[i];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
#endif
|
|
|
}
|
|
|
|
|
|
@@ -582,3 +850,18 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
|
|
|
|
|
|
return kernels;
|
|
|
}
|
|
|
+
|
|
|
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
|
|
|
+ ggml_kleidiai_kernels * kernels = nullptr;
|
|
|
+
|
|
|
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
|
|
+ for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
|
|
|
+ if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
|
|
|
+ kernels = &gemm_gemv_kernels_q8[i];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+#endif
|
|
|
+
|
|
|
+ return kernels;
|
|
|
+}
|