kleidiai.cpp
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#include <arm_neon.h>
#include <assert.h>
#include <atomic>
#include <cfloat>
#include <stdexcept>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

// C++ standard headers used directly below (std::variant/std::visit, std::invoke,
// std::min, std::index_sequence, and the type traits in the variant helpers)
#include <algorithm>
#include <functional>
#include <type_traits>
#include <utility>
#include <variant>

#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
#elif defined(__APPLE__)
#include <string_view>
#include <sys/sysctl.h>
#include <sys/types.h>
#elif defined(_WIN32)
#include <windows.h>
#include <excpt.h>
#endif

#include "kleidiai.h"

#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-threading.h"
#include "traits.h"

#include "kernels.h"

#include "kai_common.h"

#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
struct ggml_kleidiai_context {
    cpu_feature features;
    ggml_kleidiai_kernels * kernels;
} static ctx = { CPU_FEATURE_NONE, NULL };

static const char * cpu_feature_to_string(cpu_feature f) {
    switch (f) {
        case CPU_FEATURE_NONE:    return "NONE";
        case CPU_FEATURE_DOTPROD: return "DOTPROD";
        case CPU_FEATURE_I8MM:    return "I8MM";
        case CPU_FEATURE_SVE:     return "SVE";
        case CPU_FEATURE_SME:     return "SME";
        default:                  return "UNKNOWN";
    }
}
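
// One-time initialization of the global KleidiAI context: detect CPU features
// (DOTPROD, I8MM, SVE, and SME only if enabled via the GGML_KLEIDIAI_SME
// environment variable) and select the matching Q4_0 kernel set. Guarded by
// ggml's critical section so concurrent callers are safe.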
static void init_kleidiai_context(void) {

    ggml_critical_section_start();
    static bool initialized = false;

    if (!initialized) {
        initialized = true;
        const char * env_var = getenv("GGML_KLEIDIAI_SME");
        int sme_enabled = 0;

        ctx.features = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                       (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
                       (ggml_cpu_has_sve()         ? CPU_FEATURE_SVE     : CPU_FEATURE_NONE);

        if (env_var) {
            sme_enabled = atoi(env_var);
        }

        if (sme_enabled != 0) {
            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
        }
        ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
#ifndef NDEBUG
        if (ctx.kernels) {
            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
        }
#endif
    }
    ggml_critical_section_end();
}
static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
    return tensor->ne[dim];
}
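
// Helpers for the std::variant of callables stored in the kernel descriptors:
// variant_any_invocable_v checks at compile time that at least one alternative
// matches the requested return type and argument list, and variant_call
// dispatches to the matching alternative at runtime (aborting otherwise).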
template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
    using V = std::remove_reference_t<Variant>;
    return (std::is_invocable_r_v<
                Ret,
                std::variant_alternative_t<Is, V>,
                Args...> || ...);
}

template <typename Variant, typename Ret, typename... Args>
constexpr bool variant_any_invocable_v =
    variant_any_invocable_impl<Variant, Ret, Args...>(
        std::make_index_sequence<
            std::variant_size_v<std::remove_reference_t<Variant>>>{});

template<typename Ret, typename Variant, typename... Args>
static inline Ret variant_call(Variant && var, Args &&... args) {
    static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
                  "No alternative in Variant is invocable with the provided arguments and return type.");

    return std::visit(
        [&](auto && f) -> Ret {
            using F = std::decay_t<decltype(f)>;
            if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
                return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
            } else {
                GGML_ABORT("Invalid function type in variant_call");
                GGML_UNREACHABLE();
            }
        },
        std::forward<Variant>(var)
    );
}

namespace ggml::cpu::kleidiai {

static size_t round_down(size_t x, size_t y) {
    return y == 0 ? x : x - (x % y);
}
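
// Transpose an n x k FP16 matrix (row-major, rhs_stride bytes per source row)
// into a k x n FP32 matrix, converting each element with kai_cast_f32_f16.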
static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
    size_t src_stride = rhs_stride / sizeof(uint16_t);
    size_t dst_stride = n;

    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
        for (size_t n_idx = 0; n_idx < n; ++n_idx) {
            uint16_t v = *(src + k_idx + n_idx * src_stride);
            *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
        }
    }
}
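
// Tensor traits for the KleidiAI-accelerated operators: work_size() reports the
// scratch buffer needed for packing, and compute_forward() routes MUL_MAT
// (Q4_0 and F16 weights) and GET_ROWS (Q4_0) to the selected kernels.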
class tensor_traits : public ggml::cpu::tensor_traits {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        if (op->op != GGML_OP_MUL_MAT) {
            return false;
        }
        ggml_kleidiai_kernels * kernels = ggml_kleidiai_select_kernels(ctx.features, op);
        GGML_ASSERT(kernels);

        bool               is_gemv  = op->src[1]->ne[1] == 1;
        kernel_info      * kernel   = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

        size_t k  = op->src[0]->ne[0];
        size_t n  = op->src[0]->ne[1];
        size_t m  = op->src[1]->ne[1];
        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
            size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
        } else if (kernels->rhs_type == GGML_TYPE_F16) {
            const int64_t lhs_batch_size0 = op->src[1]->ne[2];
            const int64_t rhs_batch_size0 = op->src[0]->ne[2];
            const int64_t r               = lhs_batch_size0 / rhs_batch_size0;

            size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
                   variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
                   k * n * sizeof(float) + n * sizeof(float);
        } else {
            GGML_ASSERT(false);
        }

        return true;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
        if (dst->op == GGML_OP_MUL_MAT) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                return compute_forward_q4_0(params, dst);
            } else if (dst->src[0]->type == GGML_TYPE_F16) {
                return compute_forward_fp16(params, dst);
            }
        } else if (dst->op == GGML_OP_GET_ROWS) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                return compute_forward_get_rows(params, dst);
            }
        }
        return false;
    }
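
    // F16 matmul: per batch, pack the FP32 LHS (threaded over the m dimension),
    // transpose and pack the FP16 RHS on thread 0, then run the packed GEMM/GEMV
    // threaded over the n dimension. The scratch buffer sized in work_size() is
    // laid out as [lhs_packed | rhs_packed | rhs_kxn | bias].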
    bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels * kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        GGML_ASSERT(kernels);

        const bool is_gemv = src1->ne[1] == 1;
        kernel_info      * kernel   = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
        GGML_ASSERT(kernel);

        const int nth = params->nth;
        const int ith = params->ith;

        const int64_t lhs_batch_size0 = ne12;
        const int64_t rhs_batch_size0 = ne02;
        const int64_t batch_size      = lhs_batch_size0;

        GGML_ASSERT(rhs_batch_size0 > 0);
        GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
        const int64_t r = lhs_batch_size0 / rhs_batch_size0;

        const int64_t m_group = ne11;
        const int64_t m       = m_group;
        const int64_t n       = ne01;
        const int64_t k       = ne00;

        const size_t lhs_stride = src1->nb[1];
        const size_t rhs_stride = src0->nb[1];
        const size_t dst_stride = dst->nb[1];

        const int64_t mr = (int64_t) kernel->get_mr();
        const int64_t nr = (int64_t) kernel->get_nr();
        const int64_t kr = (int64_t) kernel->get_kr();
        const int64_t sr = (int64_t) kernel->get_sr();

        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
        const size_t kxn_size        = (size_t)k * (size_t)n * sizeof(float);
        const size_t bias_size       = (size_t)n * sizeof(float);

        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
        GGML_ASSERT(wsize_required <= params->wsize);

        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
        uint8_t * rhs_kxn    = rhs_packed + rhs_packed_size;
        uint8_t * bias       = rhs_kxn + kxn_size;

        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
            const int64_t rhs_batch_idx = batch_idx / r;
            const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];

            // LHS packing (threaded over m, honoring mr alignment and KV groups)
            {
                const int64_t m_roundup_mr = kai_roundup(m, mr);
                const int64_t num_threads  = KAI_MIN(m_roundup_mr / mr, nth);

                if (ith < num_threads) {
                    const int64_t num_m_per_thread0   = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;

                    const int64_t m_start = ith * num_m_per_thread0;
                    const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;

                    // Base packed offset (aligned) and per-row stride in bytes
                    const size_t base_packed_off = variant_call<size_t>(
                        lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
                    const size_t next_block_off = variant_call<size_t>(
                        lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
                    const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;

                    int64_t remaining = m_count;
                    int64_t cur       = m_start;

                    while (remaining > 0) {
                        const int64_t row_in_group = cur;
                        const int64_t avail        = m_group - row_in_group;
                        const int64_t take         = std::min(avail, remaining);

                        const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
                        const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
                        void * dst_ptr = lhs_packed + dst_off;

                        variant_call<void>(lhs_info->pack_func,
                                           (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
                                           /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);

                        cur       += take;
                        remaining -= take;
                    }
                }
            }

            // RHS packing (single thread), then synchronize
            if (ith == 0) {
                memset(bias, 0, (size_t)n * sizeof(float));
                transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
                                        reinterpret_cast<float *>(rhs_kxn),
                                        reinterpret_cast<const uint16_t *>(rhs_batch_base),
                                        rhs_stride);

                variant_call<void>(kernels->rhs_info.pack_func,
                                   /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
                                   /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
                                   rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
            }

            ggml_barrier(params->threadpool);

            // Matmul (threaded over n)
            {
                const int64_t n_step = (int64_t) kernel->get_n_step();
                int64_t num_threads_n = KAI_MIN(n / n_step, nth);
                if (num_threads_n <= 0) {
                    num_threads_n = 1;
                }

                if (ith < num_threads_n) {
                    const int64_t num_n_per_thread0   = round_down((size_t)(n / num_threads_n), (size_t)n_step);
                    const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;

                    const int64_t n_start      = ith * num_n_per_thread0;
                    const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;

                    // LHS packed base at row 0 (consistent with packing above)
                    const size_t lhs_packed_offset0 = variant_call<size_t>(
                        lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
                    const size_t dst_offset        = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);

                    const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;
                    float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);

                    variant_call<void>(kernel->run_kernel,
                                       (size_t)m, (size_t)n_to_process, (size_t)k,
                                       lhs_ptr, rhs_ptr,
                                       dst_ptr, dst_stride, sizeof(float),
                                       -FLT_MAX, FLT_MAX);
                }
            }

            if (batch_idx != batch_size - 1) {
                ggml_barrier(params->threadpool);
            }
        }

        return true;
    }
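
    // Q4_0 matmul: the RHS weights are already in the KleidiAI packed layout
    // (see repack() below, invoked from the buffer's set_tensor hook), so only
    // the FP32 LHS is packed here, threaded over the m dimension; the packed
    // GEMM/GEMV is then threaded over the n dimension.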
    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels * kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        GGML_ASSERT(kernels);

        bool               is_gemv  = src1->ne[1] == 1;
        kernel_info      * kernel   = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
        GGML_ASSERT(kernel);

        const int ith     = params->ith;
        const int nth_raw = params->nth;
        const int nth     = nth_raw > 0 ? nth_raw : 1;

        const size_t k = ne00;
        const size_t m = ne11;
        const size_t n = ne01;

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
        uint8_t       * lhs_packed = (uint8_t *) params->wdata;
        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);

        const size_t n_step = kernel->get_n_step();
        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
        const size_t n_start = ith * num_n_per_thread;

        size_t n_to_process = 0;
        if (n_start < n) {
            n_to_process = num_n_per_thread;
            if ((n_start + n_to_process) > n) {
                n_to_process = n - n_start;
            }
        }

        // Calculate the number of LHS rows to be processed per thread
        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
        const size_t m_start = ith * num_m_per_thread;
        size_t m_to_process = num_m_per_thread;
        if ((m_start + m_to_process) > m) {
            m_to_process = m - m_start;
        }

        if (m_start < m) {
            // Transform LHS
            const size_t src_stride = src1->nb[1];
            const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
            const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
            void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);

            variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
        }

        ggml_barrier(params->threadpool);

        // Perform the operation
        const size_t dst_stride        = dst->nb[1];
        const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
        const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
        const void * lhs_ptr           = (const void *)((const char *) lhs_packed + lhs_packed_offset);
        float      * dst_ptr           = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);

        if (n_to_process > 0) {
            variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
                               sizeof(float), -FLT_MAX, FLT_MAX);
        }

        return true;
    }
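
    // GET_ROWS on a KleidiAI-packed Q4_0 tensor: the rows requested by the I32
    // index tensor are dequantized straight out of the packed layout via
    // rhs_info->to_float, with the index range split across threads.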
    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
        GGML_ASSERT(ctx.kernels);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
        kernel_info      * kernel   = &ctx.kernels->gemm;

        const int64_t nc = ne00;
        const int64_t nr = ggml_nelements(src1);

        const size_t block_rows           = kernel->get_nr();
        const size_t kr                   = kernel->get_kr();
        const size_t num_bytes_multiplier = sizeof(uint16_t);
        const size_t packed_stride        = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);

        const int ith = params->ith;
        const int nth = params->nth;

        const int dr  = (nr + nth - 1) / nth;
        const int ir0 = dr * ith;
        const int ir1 = MIN(ir0 + dr, nr);

        for (int64_t i = ir0; i < ir1; ++i) {
            GGML_ASSERT(src1->type == GGML_TYPE_I32);
            int64_t row_idx = ((const int32_t *) src1->data)[i];
            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);

            float * out = (float *) ((char *) dst->data + i * nb1);
            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
        }

        return true;
    }

public:
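    // Repack a Q4_0 weight tensor from the ggml block layout into the KleidiAI
    // packed RHS layout, writing the result into the tensor's buffer. Called
    // once per tensor from the buffer's set_tensor hook.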
    int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
        GGML_ASSERT(ctx.kernels);
        const size_t n = tensor->ne[1];
        const size_t k = tensor->ne[0];
        size_t nr = ctx.kernels->gemm.get_nr();
        size_t kr = ctx.kernels->gemm.get_kr();
        size_t sr = ctx.kernels->gemm.get_sr();

        struct kai_rhs_pack_qs4cxs1s0_param params;
        params.lhs_zero_point = 1;
        params.rhs_zero_point = 8;
        variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t *) data, nullptr, tensor->data, 0, &params);

        return 0;

        GGML_UNUSED(data_size);
    }
};

static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
    static tensor_traits traits;
    return &traits;
}
}  // namespace ggml::cpu::kleidiai
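
// Buffer-type glue: tensors allocated in the KleidiAI buffer get the traits
// above attached as tensor->extra, and set_tensor() repacks Q4_0 weights into
// the KleidiAI layout as they are uploaded.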
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);

    return GGML_STATUS_SUCCESS;

    GGML_UNUSED(buffer);
}

static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                        const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
    auto OK = tensor_traits->repack(tensor, data, size);

    GGML_ASSERT(OK == 0);
    GGML_UNUSED(buffer);
}

static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU_KLEIDIAI";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);

    if (buffer == nullptr) {
        return nullptr;
    }

    buffer->buft              = buft;
    buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
    buffer->iface.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor;
    buffer->iface.get_tensor  = nullptr;
    buffer->iface.cpy_tensor  = nullptr;
    return buffer;
}

static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;

    GGML_UNUSED(buft);
}

static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
    GGML_ASSERT(ctx.kernels);

    const size_t n  = tensor->ne[1];
    const size_t k  = tensor->ne[0];
    const size_t nr = ctx.kernels->gemm.get_nr();
    const size_t kr = ctx.kernels->gemm.get_kr();

    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);

    GGML_UNUSED(buft);
}

namespace ggml::cpu::kleidiai {
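// Extra buffer type: advertises which MUL_MAT / GET_ROWS ops the KleidiAI path
// can handle and hands out the matching tensor traits.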
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
            op->src[0]->type == GGML_TYPE_Q4_0 &&
            op->src[0]->buffer &&
            (ggml_n_dims(op->src[0]) == 2) &&
            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                return true;
            }
        }
        return false;
    }

    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
            }
            else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
                if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
                    return nullptr;
                }

                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
            }
        }
        return nullptr;
    }
};
}  // namespace ggml::cpu::kleidiai
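
// Public entry point: returns the singleton KleidiAI CPU buffer type and makes
// sure the kernel selection in init_kleidiai_context() has run.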
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
    static ggml::cpu::kleidiai::extra_buffer_type ctx;
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
        /* .iface = */ {
            /* .get_name       = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
            /* .alloc_buffer   = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
            /* .get_max_size   = */ nullptr, // defaults to SIZE_MAX
            /* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
            /* .is_host        = */ nullptr,
        },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ &ctx,
    };

    init_kleidiai_context();

    return &ggml_backend_cpu_buffer_type_kleidiai;
}