// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//

#include <arm_neon.h>
#include <assert.h>
#include <atomic>
#include <cfloat>
#include <stdexcept>
#include <stdint.h>
#include <stdlib.h>     // getenv, atoi
#include <string.h>
#include <type_traits>  // std::is_invocable_r_v
#include <utility>      // std::forward
#include <variant>      // std::visit

#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
#elif defined(__APPLE__)
#include <string_view>
#include <sys/sysctl.h>
#include <sys/types.h>
#elif defined(_WIN32)
#include <windows.h>
#include <excpt.h>
#endif

#include "kleidiai.h"

#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-threading.h"
#include "traits.h"

#include "kernels.h"

#include "kai_common.h"

#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
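
// Global KleidiAI state: the CPU features detected at runtime and the kernel
// table selected for them. Filled in lazily by init_kleidiai_context().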
struct ggml_kleidiai_context {
    cpu_feature features;
    ggml_kleidiai_kernels * kernels;
} static ctx = { CPU_FEATURE_NONE, NULL };

static const char* cpu_feature_to_string(cpu_feature f) {
    switch (f) {
        case CPU_FEATURE_NONE:    return "NONE";
        case CPU_FEATURE_DOTPROD: return "DOTPROD";
        case CPU_FEATURE_I8MM:    return "I8MM";
        case CPU_FEATURE_SVE:     return "SVE";
        case CPU_FEATURE_SME:     return "SME";
        default:                  return "UNKNOWN";
    }
}
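
// One-time initialization of the global context, guarded by ggml's critical
// section so concurrent callers are safe. DOTPROD, I8MM and SVE are picked up
// automatically; SME kernels are opt-in and only considered when the
// GGML_KLEIDIAI_SME environment variable is set to a non-zero integer.
// Usage sketch (the binary name is a placeholder):
//   GGML_KLEIDIAI_SME=1 ./my_ggml_app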
static void init_kleidiai_context(void) {

    ggml_critical_section_start();
    static bool initialized = false;

    if (!initialized) {
        initialized = true;
        const char *env_var = getenv("GGML_KLEIDIAI_SME");
        int sme_enabled = 0;

        ctx.features = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                       (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
                       (ggml_cpu_has_sve()         ? CPU_FEATURE_SVE     : CPU_FEATURE_NONE);

        if (env_var) {
            sme_enabled = atoi(env_var);
        }

        if (sme_enabled != 0) {
            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
        }
        ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
#ifndef NDEBUG
        if (ctx.kernels) {
            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
        }
#endif
    }
    ggml_critical_section_end();
}

static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
    return tensor->ne[dim];
}
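
// The kernel descriptors in kernels.h expose their entry points as std::variant
// members whose alternatives differ only in signature (inferred from how they
// are used below). variant_call visits the variant and forwards its arguments
// to whichever alternative is invocable with them, e.g. (taken from work_size()):
//   size_t sz = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);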
template<typename Ret, typename Variant, typename... Args>
static Ret variant_call(const Variant & var, Args&&... args) {
    return std::visit([&](auto&& func) -> Ret {
        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
            return func(std::forward<Args>(args)...);
        } else {
            throw std::runtime_error("Invalid function type in variant_call");
        }
    }, var);
}

namespace ggml::cpu::kleidiai {

static size_t round_down(size_t x, size_t y) {
    return y == 0 ? x : x - (x % y);
}
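
// Convert an n x k f16 matrix (row stride rhs_stride bytes) into a k x n f32
// matrix, i.e. transpose while widening f16 -> f32. Used below to bring the f16
// RHS (KV cache) tensor into the layout expected by the f32 RHS packing routine.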
static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
    size_t src_stride = rhs_stride / sizeof(uint16_t);
    size_t dst_stride = n;

    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
        for (size_t n_idx = 0; n_idx < n; ++n_idx) {
            uint16_t v = *(src + k_idx + n_idx * src_stride);
            *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
        }
    }
}
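
// Per-tensor traits plugged into the CPU backend. work_size() reports the
// scratch (wdata) requirements and compute_forward() routes supported ops to
// the KleidiAI micro-kernels: Q4_0 mat-mul, f16 (KV cache) mat-mul and Q4_0
// get-rows.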
class tensor_traits : public ggml::cpu::tensor_traits {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        if (op->op != GGML_OP_MUL_MAT) {
            return false;
        }
        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
        GGML_ASSERT(kernels);
        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;

        size_t k  = op->src[0]->ne[0];
        size_t n  = op->src[0]->ne[1];
        size_t m  = op->src[1]->ne[1];
        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
        } else if (kernels->rhs_type == GGML_TYPE_F16) {
            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
                   variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
                   k * n * sizeof(float) + n * sizeof(float);
        } else {
            GGML_ASSERT(false);
        }

        return true;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
        if (dst->op == GGML_OP_MUL_MAT) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                return compute_forward_q4_0(params, dst);
            } else if (dst->src[0]->type == GGML_TYPE_F16) {
                return compute_forward_kv_cache(params, dst);
            }
        } else if (dst->op == GGML_OP_GET_ROWS) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                return compute_forward_get_rows(params, dst);
            }
        }
        return false;
    }
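
    // f16 (KV cache) mat-mul. The scratch buffer params->wdata is laid out as
    // [lhs_packed | rhs_packed | rhs_kxn | bias], matching work_size() above.
    // Per batch: the LHS rows are packed in parallel, the first thread to arrive
    // transposes and packs the f16 RHS (with a zero bias), all threads meet at a
    // barrier, and the N dimension is then split across threads for the mat-mul.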
    bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
        static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        GGML_ASSERT(kernels);

        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
        GGML_ASSERT(kernel);

        const int nth = params->nth;
        const int ith = params->ith;

        const int64_t lhs_batch_size0 = ne12;
        const int64_t rhs_batch_size0 = ne02;
        const int64_t batch_size      = rhs_batch_size0;

        const int64_t r = lhs_batch_size0 / rhs_batch_size0;

        const int64_t m = ne11 * r;
        const int64_t n = ne01;
        const int64_t k = ne00;

        const size_t lhs_stride = src1->nb[1];
        const size_t rhs_stride = src0->nb[1];
        const size_t dst_stride = dst->nb[1];

        const int64_t mr = static_cast<int64_t>(kernel->get_mr());
        const int64_t nr = static_cast<int64_t>(kernel->get_nr());
        const int64_t kr = static_cast<int64_t>(kernel->get_kr());
        const int64_t sr = static_cast<int64_t>(kernel->get_sr());

        const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
        const size_t kxn_size        = k * n * sizeof(float);
        const size_t bias_size       = n * sizeof(float);

        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
        GGML_ASSERT(wsize_required <= params->wsize);

        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
        uint8_t * rhs_kxn    = rhs_packed + rhs_packed_size;
        uint8_t * bias       = rhs_kxn + kxn_size;

        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
            const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
            const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
            uint8_t * dst_batch       = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;

            // LHS packing
            {
                const int64_t m_roundup_mr = kai_roundup(m, mr);
                const int64_t num_threads  = KAI_MIN(m_roundup_mr / mr, nth);

                if (ith < num_threads) {
                    const int64_t num_m_per_thread0   = round_down(m_roundup_mr / num_threads, mr);
                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;

                    const int64_t m_start          = ith * num_m_per_thread0;
                    const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;

                    const size_t lhs_offset        = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
                    const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);

                    const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
                    void * dst_ptr       = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;

                    variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
                }
            }

            // RHS packing
            if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
                // First thread to reach this point handles RHS packing
                memset(bias, 0, n * sizeof(float));
                transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
                                        reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);

                variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
                                   rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
            }

            ggml_barrier(params->threadpool);

            first_to_arrive.clear(std::memory_order_release);

            // Perform the matmul
            {
                const int64_t m_to_process = m;
                const int64_t m_start      = 0;

                const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
                int64_t num_threads  = KAI_MIN(n / n_step, nth);
                if (num_threads <= 0) {
                    num_threads = 1;
                }

                if (ith < num_threads) {
                    const int64_t num_n_per_thread0   = round_down(n / num_threads, n_step);
                    const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;

                    const int64_t n_start      = ith * num_n_per_thread0;
                    const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;

                    const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
                    const size_t dst_offset        = kernel->get_dst_offset(m_start, n_start, dst_stride);

                    const void * lhs_ptr = lhs_packed + lhs_packed_offset;
                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;
                    float * dst_ptr      = reinterpret_cast<float *>(dst_batch + dst_offset);

                    variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
                }
            }

            if (batch_idx != batch_size - 1) {
                // This barrier is necessary when the batch size is larger than 1. While processing a batch,
                // the work data buffer (params->wdata) is used as temporary storage which means that only
                // a single batch can be processed at any given time. No barrier is needed for the last
                // batch since GGML inserts a barrier between the execution of every operator.
                ggml_barrier(params->threadpool);
            }
        }

        return true;
    }
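
    // Q4_0 mat-mul. The Q4_0 RHS was already converted to KleidiAI's packed
    // layout at weight-upload time (see repack() below), so src0->data is used
    // directly. Each thread quantizes and packs its slice of the f32 LHS into
    // params->wdata, all threads synchronize at a barrier, and the output is
    // then computed with each thread owning a contiguous range of N columns.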
    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        GGML_ASSERT(kernels);

        kernel_info * kernel        = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = &kernels->lhs_info;

        GGML_ASSERT(kernel);

        const int ith     = params->ith;
        const int nth_raw = params->nth;
        const int nth     = nth_raw > 0 ? nth_raw : 1;

        const size_t k = ne00;
        const size_t m = ne11;
        const size_t n = ne01;

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
        uint8_t * lhs_packed       = (uint8_t*)params->wdata;
        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);

        const size_t n_step = kernel->get_n_step();
        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
        const size_t n_start = ith * num_n_per_thread;

        size_t n_to_process = 0;
        if (n_start < n) {
            n_to_process = num_n_per_thread;
            if ((n_start + n_to_process) > n) {
                n_to_process = n - n_start;
            }
        }

        // Calculate number of LHS rows to be processed per thread
        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
        const size_t m_start = ith * num_m_per_thread;
        size_t m_to_process = num_m_per_thread;
        if ((m_start + m_to_process) > m) {
            m_to_process = m - m_start;
        }

        if (m_start < m) {
            // Transform LHS
            const size_t src_stride        = src1->nb[1];
            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
            const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);

            variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
        }

        ggml_barrier(params->threadpool);

        // Perform the operation
        const size_t dst_stride        = dst->nb[1];
        const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
        const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
        const void * lhs_ptr           = (const void*)((const char *)lhs_packed + lhs_packed_offset);
        float * dst_ptr                = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);

        if (n_to_process > 0) {
            variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
                               sizeof(float), -FLT_MAX, FLT_MAX);
        }

        return true;
    }
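
    // Q4_0 get-rows: dequantize selected rows of the KleidiAI-packed RHS back to
    // f32. src1 holds int32 row indices; the index range is split evenly across
    // threads and each selected row is expanded via rhs_info->to_float().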
    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
        GGML_ASSERT(ctx.kernels);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
        kernel_info * kernel        = &ctx.kernels->gemm;

        const int64_t nc = ne00;
        const int64_t nr = ggml_nelements(src1);

        const size_t block_rows = kernel->get_nr();
        const size_t kr         = kernel->get_kr();

        const size_t num_bytes_multiplier = sizeof(uint16_t);
        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);

        const int ith = params->ith;
        const int nth = params->nth;

        const int dr  = (nr + nth - 1) / nth;
        const int ir0 = dr * ith;
        const int ir1 = MIN(ir0 + dr, nr);

        for (int64_t i = ir0; i < ir1; ++i) {
            GGML_ASSERT(src1->type == GGML_TYPE_I32);
            int64_t row_idx = ((const int32_t *)src1->data)[i];
            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);

            float *out = (float *)((char *)dst->data + i * nb1);
            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
        }

        return true;
    }

public:
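    // Convert a Q4_0 tensor uploaded through set_tensor() into KleidiAI's packed
    // RHS layout, written directly into tensor->data. The zero-point parameters
    // are the values these KleidiAI Q4 packing routines expect; rhs_zero_point = 8
    // corresponds to Q4_0 storing its 4-bit values with an offset of 8.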
    int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
        GGML_ASSERT(ctx.kernels);
        const size_t n = tensor->ne[1];
        const size_t k = tensor->ne[0];
        size_t nr = ctx.kernels->gemm.get_nr();
        size_t kr = ctx.kernels->gemm.get_kr();
        size_t sr = ctx.kernels->gemm.get_sr();

        struct kai_rhs_pack_qs4cxs1s0_param params;
        params.lhs_zero_point = 1;
        params.rhs_zero_point = 8;

        variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);

        return 0;

        GGML_UNUSED(data_size);
    }
};

static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
    static tensor_traits traits;
    return &traits;
}
} // namespace ggml::cpu::kleidiai
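
// Buffer hooks for the KleidiAI buffer type: init_tensor attaches the shared
// tensor_traits instance to tensor->extra, and set_tensor repacks Q4_0 weights
// into the KleidiAI layout as they are uploaded.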
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);

    return GGML_STATUS_SUCCESS;

    GGML_UNUSED(buffer);
}

static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                        const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
    auto OK = tensor_traits->repack(tensor, data, size);

    GGML_ASSERT(OK == 0);
    GGML_UNUSED(buffer);
}

static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU_KLEIDIAI";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);

    if (buffer == nullptr) {
        return nullptr;
    }

    buffer->buft              = buft;
    buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
    buffer->iface.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor;
    buffer->iface.get_tensor  = nullptr;
    buffer->iface.cpy_tensor  = nullptr;
    return buffer;
}

static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;

    GGML_UNUSED(buft);
}
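
// The allocation size of a Q4_0 tensor in this buffer type is the size of its
// KleidiAI-packed form (as reported by the RHS packing routine), which may
// differ from ggml_nbytes() for the plain Q4_0 layout.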
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
    GGML_ASSERT(ctx.kernels);

    const size_t n  = tensor->ne[1];
    const size_t k  = tensor->ne[0];
    const size_t nr = ctx.kernels->gemm.get_nr();
    const size_t kr = ctx.kernels->gemm.get_kr();

    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);

    GGML_UNUSED(buft);
}
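
// extra_buffer_type advertises which ops the KleidiAI path can take over:
// 2-D Q4_0 weights living in the KleidiAI buffer type (MUL_MAT, and GET_ROWS
// only when src[1]->ne[0] == 8), plus f16 KV-cache mat-muls that satisfy the
// stride/layout checks in get_tensor_traits().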
namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
            op->src[0]->type == GGML_TYPE_Q4_0 &&
            op->src[0]->buffer &&
            (ggml_n_dims(op->src[0]) == 2) &&
            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
                return false;
            }
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                return true;
            }
        }
        return false;
    }

    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
            }
            else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
                     op->src[0]->op == GGML_OP_VIEW &&
                     (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
                     op->src[1]->ne[1] > 1) {
                if ((op->src[0]->nb[0] != 2) ||
                    (op->src[1]->nb[0] != 4) ||
                    (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
                    return nullptr;
                }

                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
            }
        }
        return nullptr;
    }
};
} // namespace ggml::cpu::kleidiai
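
// Public entry point: returns the singleton KleidiAI buffer type. It reuses the
// regular CPU buffer for storage and only overrides the init/set tensor hooks.
// Usage sketch (assuming the usual ggml-backend buffer API):
//   ggml_backend_buffer_type_t buft = ggml_backend_cpu_kleidiai_buffer_type();
//   ggml_backend_buffer_t     buf  = ggml_backend_buft_alloc_buffer(buft, size);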
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
    static ggml::cpu::kleidiai::extra_buffer_type ctx;
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
        /* .iface    = */ {
            /* .get_name       = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
            /* .alloc_buffer   = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
            /* .get_max_size   = */ nullptr, // defaults to SIZE_MAX
            /* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
            /* .is_host        = */ nullptr,
        },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ &ctx,
    };

    init_kleidiai_context();

    return &ggml_backend_cpu_buffer_type_kleidiai;
}