// ggml-amx.cpp

#include "ggml-amx.h"
#include "ggml-amx/common.h"
#include "ggml-amx/mmq.h"
#include "ggml-backend-impl.h"
#include "ggml-impl.h"

#if defined(__gnu_linux__)
#include <sys/syscall.h>
#include <unistd.h>
#endif

#include <cstdlib>
#include <cstring>
#include <memory>

#if defined(__AMX_INT8__)

// AMX buffer interface
static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    free(buffer->context);
}

static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
    return (void *)(buffer->context);
}

static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
    memset((char *)tensor->data + offset, value, size);

    GGML_UNUSED(buffer);
}

static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    if (qtype_has_amx_kernels(tensor->type)) {
        ggml_backend_amx_convert_weight(tensor, data, offset, size);
    } else {
        memcpy((char *)tensor->data + offset, data, size);
    }

    GGML_UNUSED(buffer);
}

static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
    memcpy(data, (const char *)tensor->data + offset, size);

    GGML_UNUSED(buffer);
}

static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
    if (ggml_backend_buffer_is_host(src->buffer)) {
        if (qtype_has_amx_kernels(src->type)) {
            ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
        } else {
            memcpy(dst->data, src->data, ggml_nbytes(src));
        }
        return true;
    }
    return false;

    GGML_UNUSED(buffer);
}

static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    memset(buffer->context, value, buffer->size);
}

static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
    /* .free_buffer   = */ ggml_backend_amx_buffer_free_buffer,
    /* .get_base      = */ ggml_backend_amx_buffer_get_base,
    /* .init_tensor   = */ NULL, // no initialization required
    /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
    /* .set_tensor    = */ ggml_backend_amx_buffer_set_tensor,
    /* .get_tensor    = */ ggml_backend_amx_buffer_get_tensor,
    /* .cpy_tensor    = */ ggml_backend_amx_buffer_cpy_tensor,
    /* .clear         = */ ggml_backend_amx_buffer_clear,
    /* .reset         = */ NULL,
};

static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "AMX";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
    if (data == NULL) {
        fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
        return NULL;
    }

    return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
}

static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;

    GGML_UNUSED(buft);
}

static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    return ggml_backend_amx_get_alloc_size(tensor);

    GGML_UNUSED(buft);
}

static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    return false;

    GGML_UNUSED(buft);
}

ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
    static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
        /* .iface = */ {
            /* .get_name       = */ ggml_backend_amx_buffer_type_get_name,
            /* .alloc_buffer   = */ ggml_backend_amx_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_amx_buffer_type_get_alignment,
            /* .get_max_size   = */ NULL, // defaults to SIZE_MAX
            /* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
            /* .is_host        = */ ggml_backend_amx_buffer_type_is_host,
        },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
        /* .context = */ NULL,
    };

    return &ggml_backend_buffer_type_amx;
}
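
// Usage sketch (illustrative only, not compiled here): model weights are placed in
// this buffer type so that set_tensor can repack quantized data into the AMX tile
// layout at load time. Assuming the standard ggml-backend/ggml-alloc helpers, with
// ctx, K, N and host_data as placeholders supplied by the caller:
//
//     struct ggml_tensor * weight = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, K, N);
//     ggml_backend_buffer_t buf   = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_amx_buffer_type());
//     // goes through ggml_backend_amx_buffer_set_tensor above, which converts the weight
//     ggml_backend_tensor_set(weight, host_data, 0, ggml_nbytes(weight));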
// backend interface

static const char * ggml_backend_amx_name(ggml_backend_t backend) {
    return "AMX";

    GGML_UNUSED(backend);
}

static void ggml_backend_amx_free(ggml_backend_t backend) {
    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
    delete ctx;
    delete backend;
}

static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;

    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];

        switch (node->op) {
            case GGML_OP_MUL_MAT:
                ggml_backend_amx_mul_mat(ctx, node);
                break;

            case GGML_OP_NONE:
            case GGML_OP_RESHAPE:
            case GGML_OP_VIEW:
            case GGML_OP_PERMUTE:
            case GGML_OP_TRANSPOSE:
                break;

            default:
                fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
                GGML_ASSERT(false);
        }
    }

    return GGML_STATUS_SUCCESS;

    GGML_UNUSED(backend);
}

static struct ggml_backend_i ggml_backend_amx_i = {
    /* .get_name           = */ ggml_backend_amx_name,
    /* .free               = */ ggml_backend_amx_free,
    /* .set_tensor_async   = */ NULL,
    /* .get_tensor_async   = */ NULL,
    /* .cpy_tensor_async   = */ NULL,
    /* .synchronize        = */ NULL,
    /* .graph_plan_create  = */ NULL,
    /* .graph_plan_free    = */ NULL,
    /* .graph_plan_update  = */ NULL,
    /* .graph_plan_compute = */ NULL,
    /* .graph_compute      = */ ggml_backend_amx_graph_compute,
    /* .event_record       = */ NULL,
    /* .event_wait         = */ NULL,
};

static ggml_guid_t ggml_backend_amx_guid() {
    static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
    return &guid;
}
#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILECFG   17
#define XFEATURE_XTILEDATA  18

static bool ggml_amx_init() {
#if defined(__gnu_linux__)
    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
        fprintf(stderr, "AMX is not ready to be used!\n");
        return false;
    }
    return true;
#elif defined(_WIN32)
    return true;
#else
    // other platforms: assume no explicit permission request is needed
    return true;
#endif
}
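
// Note (illustrative): ARCH_GET_XCOMP_PERM, defined above but not used here, can be
// used to query which XSTATE components the process is already permitted to use.
// A minimal sketch, assuming a Linux kernel with AMX support:
//
//     unsigned long features = 0;
//     if (syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &features) == 0 &&
//         (features & (1UL << XFEATURE_XTILEDATA))) {
//         // tile data registers are already permitted for this process
//     }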
ggml_backend_t ggml_backend_amx_init() {
    // invoke a Linux system call to request access to AMX features
    ggml_amx_init();

    // backend context
    ggml_backend_amx_context * ctx = new ggml_backend_amx_context;

    // ggml amx backend
    ggml_backend_t backend = new ggml_backend {
        /* .guid      = */ ggml_backend_amx_guid(),
        /* .interface = */ ggml_backend_amx_i,
        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
        /* .context   = */ ctx,
    };

    return backend;
}
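
// Usage sketch (illustrative only): creating and configuring the backend.
// ggml_backend_graph_compute and ggml_backend_free are the generic ggml-backend
// entry points assumed here; only the init and thread-count calls are AMX-specific,
// and cgraph stands in for a graph built by the caller.
//
//     ggml_backend_t backend = ggml_backend_amx_init();
//     GGML_ASSERT(ggml_backend_is_amx(backend));
//     ggml_backend_amx_set_n_threads(backend, 8);
//     // ... build a cgraph whose MUL_MAT weights live in the AMX buffer type ...
//     ggml_backend_graph_compute(backend, cgraph);
//     ggml_backend_free(backend);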
bool ggml_backend_is_amx(ggml_backend_t backend) {
    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
}

void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
    GGML_ASSERT(ggml_backend_is_amx(backend_amx));

    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
    ctx->n_threads = n_threads;
}

// device interface

static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
    return "AMX";

    GGML_UNUSED(dev);
}

static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
    return "Intel Advanced Matrix Extensions";

    GGML_UNUSED(dev);
}

static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    // TODO
    *free = 0;
    *total = 0;

    GGML_UNUSED(dev);
}

static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
    return GGML_BACKEND_DEVICE_TYPE_ACCEL;

    GGML_UNUSED(dev);
}
static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name        = ggml_backend_amx_device_get_name(dev);
    props->description = ggml_backend_amx_device_get_description(dev);
    props->type        = ggml_backend_amx_device_get_type(dev);
    ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);

    // `buffer_from_host_ptr` is intended for mmap-ed weights, where the memory layout stays unchanged
    props->caps = {
        /* .async                = */ false,
        /* .host_buffer          = */ false,
        /* .buffer_from_host_ptr = */ false,
        /* .events               = */ false,
    };
}
static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
    return ggml_backend_amx_init();

    GGML_UNUSED(dev);
    GGML_UNUSED(params);
}

static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
    return ggml_backend_amx_buffer_type();

    GGML_UNUSED(dev);
}
static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    // handle only 2d gemm for now
    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
    };

    switch (op->op) {
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
            return true;

        case GGML_OP_MUL_MAT: {
            const struct ggml_tensor * src0 = op->src[0];
            const struct ggml_tensor * src1 = op->src[1];

            const enum ggml_type type = src0->type;
            const int64_t ne0 = op->ne[0];

            bool is_training = src0->grad || src1->grad;

            // AMX kernels are enabled for Q4_0, Q4_1, Q8_0 and F16;
            // Q4_K, Q5_K, Q6_K and IQ4_XS are enabled when QK_K == 256
            bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);

            bool can_use_amx =
                is_contiguous_2d(src0) &&       // src0 must be contiguous
                is_contiguous_2d(src1) &&       // src1 must be contiguous
                !is_training &&                 // inference only
                src1->type == GGML_TYPE_F32 &&  // src1 must be float32
                has_amx_kernels &&              // with amx kernel impls
                ne0 % (TILE_N * 2) == 0;        // out_features must be a multiple of 32

            return can_use_amx;
        }

        default:
            return false;
    }

    GGML_UNUSED(dev);
}
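
// Example (illustrative): since the comment above implies TILE_N * 2 == 32, a
// contiguous 2D Q4_0 weight of shape [K, 4096] multiplied by an F32 activation
// matrix is routed to the AMX kernels, while one with ne0 == 4100 (not a multiple
// of 32) is rejected here and handled by another backend.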
static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;

    GGML_UNUSED(dev);
}

static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
    /* .get_name             = */ ggml_backend_amx_device_get_name,
    /* .get_description      = */ ggml_backend_amx_device_get_description,
    /* .get_memory           = */ ggml_backend_amx_device_get_memory,
    /* .get_type             = */ ggml_backend_amx_device_get_type,
    /* .get_props            = */ ggml_backend_amx_device_get_props,
    /* .init_backend         = */ ggml_backend_amx_device_init,
    /* .get_buffer_type      = */ ggml_backend_amx_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ NULL,
    /* .supports_op          = */ ggml_backend_amx_device_supports_op,
    /* .supports_buft        = */ ggml_backend_amx_device_supports_buft,
    /* .offload_op           = */ NULL,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};

// backend reg interface

static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
    return "AMX";

    GGML_UNUSED(reg);
}

static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
    return 1;

    GGML_UNUSED(reg);
}

static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    GGML_ASSERT(index == 0);

    static ggml_backend_device ggml_backend_amx_device = {
        /* .iface   = */ ggml_backend_amx_device_i,
        /* .reg     = */ reg,
        /* .context = */ nullptr,
    };

    return &ggml_backend_amx_device;

    GGML_UNUSED(reg);
    GGML_UNUSED(index);
}

static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
        return (void *)ggml_backend_amx_set_n_threads;
    }

    return NULL;

    GGML_UNUSED(reg);
    GGML_UNUSED(name);
}
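
// Usage sketch (illustrative only): callers normally reach the thread setter through
// the public registry API rather than this static function. Assuming the generic
// ggml_backend_reg_get_proc_address helper, with backend as a placeholder handle:
//
//     typedef void (*set_n_threads_fn)(ggml_backend_t, int);
//     auto fn = (set_n_threads_fn) ggml_backend_reg_get_proc_address(ggml_backend_amx_reg(), "ggml_backend_set_n_threads");
//     if (fn != NULL) {
//         fn(backend, 8);
//     }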
static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
    /* .get_name         = */ ggml_backend_amx_reg_get_name,
    /* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
    /* .get_device       = */ ggml_backend_amx_reg_get_device,
    /* .get_proc_address = */ ggml_backend_amx_get_proc_address,
};

ggml_backend_reg_t ggml_backend_amx_reg(void) {
    static struct ggml_backend_reg ggml_backend_amx_reg = {
        /* .iface   = */ ggml_backend_amx_reg_i,
        /* .context = */ NULL,
    };

    return &ggml_backend_amx_reg;
}

#else // if defined(__AMX_INT8__)

ggml_backend_t ggml_backend_amx_init(void) {
    fprintf(stderr, "GGML is not compiled with AMX support!\n");
    return ggml_backend_t{};
}

void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
    fprintf(stderr, "GGML is not compiled with AMX support!\n");

    GGML_UNUSED(backend_amx);
    GGML_UNUSED(n_threads);
}

#endif // defined(__AMX_INT8__)