// ggml-metal-device.cpp

#include "ggml-metal-device.h"
#include "ggml-metal-impl.h"

#include "ggml-impl.h"

#include <algorithm> // std::min (assumed; not present in the extracted listing)
#include <cassert>
#include <cstdio>    // snprintf (assumed; not present in the extracted listing)
#include <memory>
#include <string>
#include <unordered_map>
struct ggml_metal_device_deleter {
    void operator()(ggml_metal_device_t ctx) {
        ggml_metal_device_free(ctx);
    }
};

typedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr;

ggml_metal_device_t ggml_metal_device_get(void) {
    static ggml_metal_device_ptr ctx { ggml_metal_device_init() };

    return ctx.get();
}

struct ggml_metal_pipelines {
    std::unordered_map<std::string, ggml_metal_pipeline_t> data;
};

ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
    ggml_metal_pipelines_t res = new ggml_metal_pipelines();

    return res;
}

void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
    if (!ppls) {
        return;
    }

    for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
        ggml_metal_pipeline_free(it->second);
    }

    delete ppls;
}

void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline) {
    ppls->data[name] = pipeline;
}

ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
    if (ppls->data.find(name) == ppls->data.end()) {
        return nullptr;
    }

    return ppls->data[name];
}
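
// All of the ggml_metal_library_get_pipeline_* getters below follow the same shape:
// build a `base` string (the Metal kernel function name) and a `name` string (the cache
// key: `base` plus any specialization parameters), look the pipeline up in the library,
// and compile it on a cache miss, optionally baking parameters in as Metal function
// constants via the ggml_metal_cv_* helpers. Illustrative sketch of the pattern only
// (not a real getter; <op>, <types>, <param> and FC_<KERNEL> are placeholders):
//
//     char base[256];
//     char name[256];
//
//     snprintf(base, 256, "kernel_<op>_<types>");
//     snprintf(name, 256, "%s_<param>=%d", base, param);
//
//     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
//     if (!res.pipeline) {
//         ggml_metal_cv_t cv = ggml_metal_cv_init();
//         ggml_metal_cv_set_int16(cv, param, FC_<KERNEL> + 0); // optional function constants
//         res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
//         ggml_metal_cv_free(cv);
//     }
//
//     return res;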
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
    char base[256];
    char name[256];

    const char * op_str = "undefined";
    switch (op) {
        case GGML_OP_ADD_ID: op_str = "add_id"; break;
        case GGML_OP_CONCAT: op_str = "concat"; break;
        default: GGML_ABORT("fatal error");
    };

    snprintf(base, 256, "kernel_%s", op_str);
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);

    const char * pool_str = "undefined";
    switch (op_pool) {
        case GGML_OP_POOL_AVG: pool_str = "avg"; break;
        case GGML_OP_POOL_MAX: pool_str = "max"; break;
        default: GGML_ASSERT(false && "not implemented");
    };

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(ggml_is_contiguous(op->src[0]));

    char base[256];
    char name[256];

    const int64_t n = ggml_nelements(op);

    const char * op_str = "undefined";
    switch (op->op) {
        case GGML_OP_SCALE: op_str = "scale"; break;
        case GGML_OP_FILL: op_str = "fill"; break;
        case GGML_OP_CLAMP: op_str = "clamp"; break;
        case GGML_OP_SQR: op_str = "sqr"; break;
        case GGML_OP_SQRT: op_str = "sqrt"; break;
        case GGML_OP_SIN: op_str = "sin"; break;
        case GGML_OP_COS: op_str = "cos"; break;
        case GGML_OP_LOG: op_str = "log"; break;
        case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
                case GGML_UNARY_OP_TANH: op_str = "tanh"; break;
                case GGML_UNARY_OP_RELU: op_str = "relu"; break;
                case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break;
                case GGML_UNARY_OP_GELU: op_str = "gelu"; break;
                case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break;
                case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break;
                case GGML_UNARY_OP_SILU: op_str = "silu"; break;
                case GGML_UNARY_OP_ELU: op_str = "elu"; break;
                case GGML_UNARY_OP_NEG: op_str = "neg"; break;
                case GGML_UNARY_OP_ABS: op_str = "abs"; break;
                case GGML_UNARY_OP_SGN: op_str = "sgn"; break;
                case GGML_UNARY_OP_STEP: op_str = "step"; break;
                case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
                case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
                case GGML_UNARY_OP_EXP: op_str = "exp"; break;
                case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
                case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
                default: GGML_ABORT("fatal error");
            } break;
        default: GGML_ABORT("fatal error");
    };

    const char * suffix = "";
    if (n % 4 == 0) {
        suffix = "_4";
    }

    snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));

    char base[256];
    char name[256];

    const char * op_str = "undefined";
    switch (op->op) {
        case GGML_OP_GLU:
            switch (ggml_get_glu_op(op)) {
                case GGML_GLU_OP_REGLU: op_str = "reglu"; break;
                case GGML_GLU_OP_GEGLU: op_str = "geglu"; break;
                case GGML_GLU_OP_SWIGLU: op_str = "swiglu"; break;
                case GGML_GLU_OP_SWIGLU_OAI: op_str = "swiglu_oai"; break;
                case GGML_GLU_OP_GEGLU_ERF: op_str = "geglu_erf"; break;
                case GGML_GLU_OP_GEGLU_QUICK: op_str = "geglu_quick"; break;
                default: GGML_ABORT("fatal error");
            } break;
        default: GGML_ABORT("fatal error");
    };

    snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_SUM);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));

    char base[256];
    char name[256];

    const char * op_str = "undefined";
    switch (op->op) {
        case GGML_OP_SUM_ROWS:
            op_str = "sum_rows"; break;
        case GGML_OP_MEAN:
            op_str = "mean"; break;
        default: GGML_ABORT("fatal error");
    };

    snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    res.smem = 32*sizeof(float);

    return res;
}
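
// Note: the 32*sizeof(float) of threadgroup memory reserved here (and by several getters
// below) is one float per lane of a 32-wide simdgroup, presumably used as scratch for the
// final cross-simdgroup reduction step in the corresponding kernels.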
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(op->op == GGML_OP_CUMSUM);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(op->op == GGML_OP_CUMSUM);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(op->op == GGML_OP_TRI);
    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));

    char base[256];
    char name[256];

    const char * op_str = "tri";
    const int ttype = op->op_params[0];

    snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);

    char base[256];
    char name[256];

    const char * suffix = "";
    if (op->src[0]->ne[0] % 4 == 0) {
        suffix = "_4";
    }

    const ggml_type tsrc1 = op->src[1] ? op->src[1]->type : GGML_TYPE_F32;

    snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix);
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    res.smem = 32*sizeof(float);

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
    GGML_ASSERT(ggml_is_contiguous(op->src[1]));

    char base[256];
    char name[256];

    const char * suffix = "";
    if (op->src[1]->ne[0] % 4 == 0) {
        suffix = "_4";
    }

    snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
    GGML_ASSERT(ggml_is_contiguous(op->src[1]));

    char base[256];
    char name[256];

    const char * suffix = "";
    if (op->src[1]->ne[0] % 4 == 0) {
        suffix = "_4";
    }

    snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
    snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);

    char base[256];
    char name[256];

    const int nsg = (ne00 + 31)/32;

    snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s_nsg=%d", base, nsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    // Shared memory layout:
    //   - sgptg * NW floats for partial sums (nsg * 32)
    //   - sgptg floats for shared_x_dt (nsg)
    //   - sgptg floats for shared_dA (nsg)
    // Total: nsg * (32 + 2) floats
    res.smem = (32 + 2)*sizeof(float)*nsg;

    return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
    char base[256];
    char name[256];

    const int64_t C = op->ne[0];
    const int64_t H = op->src[0]->ne[1];

    switch (op->op) {
        case GGML_OP_RWKV_WKV6:
            {
                GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
                GGML_ASSERT(C % H == 0);
                GGML_ASSERT(C / H == 64);

                snprintf(base, 256, "kernel_rwkv_wkv6_%s", ggml_type_name(op->src[0]->type));
            } break;
        case GGML_OP_RWKV_WKV7:
            {
                GGML_ASSERT(op->src[6]->type == GGML_TYPE_F32);
                GGML_ASSERT(C % H == 0);
                GGML_ASSERT(C / H == 64);

                snprintf(base, 256, "kernel_rwkv_wkv7_%s", ggml_type_name(op->src[0]->type));
            } break;
        default:
            GGML_ABORT("fatal error");
    }

    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
    snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, nsg,   FC_MUL_MV + 0);
        ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
    char base[256];
    char name[256];

    const ggml_type tsrc0 = op->src[0]->type;
    const ggml_type tsrc1 = op->src[1]->type;

    const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
    const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;

    snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
    snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
        ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    // when the output size is not a multiple of 64x32, we need extra smem to prevent out-of-bounds writes
    res.smem = bc_out ? 8192 : 4096 + 2048;

    return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(int32_t, ne1, op->src[1], ne);

    char base[256];
    char name[256];

    int nsg = 0; // number of simdgroups
    int nr0 = 0; // number of src0 rows per simdgroup
    int nr1 = 1; // number of src1 rows per threadgroup

    size_t smem = 0; // shared memory

    const ggml_type tsrc0 = op->src[0]->type;
    const ggml_type tsrc1 = op->src[1]->type;

    const char * suffix = "";

    // use custom matrix x vector kernel
    switch (tsrc0) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
            {
                if (ne00 < 32) {
                    nsg = 1;
                    nr0 = 32;
                    nr1 = 1;
                    suffix = "_short";
                } else {
                    nsg = std::min(4, (ne00 + 127) / 128);
                    nr0 = 2;
                    nr1 = 1;
                    smem = 32*sizeof(float)*nr0;
                    suffix = ne00 % 4 == 0 ? "_4" : "";
                }
            } break;
        case GGML_TYPE_Q4_0:
            {
                nsg = N_SG_Q4_0;
                nr0 = N_R0_Q4_0;
            } break;
        case GGML_TYPE_Q4_1:
            {
                nsg = N_SG_Q4_1;
                nr0 = N_R0_Q4_1;
            } break;
        case GGML_TYPE_Q5_0:
            {
                nsg = N_SG_Q5_0;
                nr0 = N_R0_Q5_0;
            } break;
        case GGML_TYPE_Q5_1:
            {
                nsg = N_SG_Q5_1;
                nr0 = N_R0_Q5_1;
            } break;
        case GGML_TYPE_Q8_0:
            {
                nsg = N_SG_Q8_0;
                nr0 = N_R0_Q8_0;
                smem = 32*sizeof(float)*N_R0_Q8_0;
            } break;
        case GGML_TYPE_MXFP4:
            {
                nsg = N_SG_MXFP4;
                nr0 = N_R0_MXFP4;
                smem = 32*sizeof(float);
            } break;
        case GGML_TYPE_Q2_K:
            {
                nsg = N_SG_Q2_K;
                nr0 = N_R0_Q2_K;
            } break;
        case GGML_TYPE_Q3_K:
            {
                nsg = N_SG_Q3_K;
                nr0 = N_R0_Q3_K;
            } break;
        case GGML_TYPE_Q4_K:
            {
                nsg = N_SG_Q4_K;
                nr0 = N_R0_Q4_K;
            } break;
        case GGML_TYPE_Q5_K:
            {
                nsg = N_SG_Q5_K;
                nr0 = N_R0_Q5_K;
            } break;
        case GGML_TYPE_Q6_K:
            {
                nsg = N_SG_Q6_K;
                nr0 = N_R0_Q6_K;
            } break;
        case GGML_TYPE_IQ2_XXS:
            {
                nsg = N_SG_IQ2_XXS;
                nr0 = N_R0_IQ2_XXS;
                smem = 256*8+128;
            } break;
        case GGML_TYPE_IQ2_XS:
            {
                nsg = N_SG_IQ2_XS;
                nr0 = N_R0_IQ2_XS;
                smem = 512*8+128;
            } break;
        case GGML_TYPE_IQ3_XXS:
            {
                nsg = N_SG_IQ3_XXS;
                nr0 = N_R0_IQ3_XXS;
                smem = 256*4+128;
            } break;
        case GGML_TYPE_IQ3_S:
            {
                nsg = N_SG_IQ3_S;
                nr0 = N_R0_IQ3_S;
                smem = 512*4;
            } break;
        case GGML_TYPE_IQ2_S:
            {
                nsg = N_SG_IQ2_S;
                nr0 = N_R0_IQ2_S;
            } break;
        case GGML_TYPE_IQ1_S:
            {
                nsg = N_SG_IQ1_S;
                nr0 = N_R0_IQ1_S;
            } break;
        case GGML_TYPE_IQ1_M:
            {
                nsg = N_SG_IQ1_M;
                nr0 = N_R0_IQ1_M;
            } break;
        case GGML_TYPE_IQ4_NL:
            {
                nsg = N_SG_IQ4_NL;
                nr0 = N_R0_IQ4_NL;
                smem = 32*sizeof(float);
            } break;
        case GGML_TYPE_IQ4_XS:
            {
                nsg = N_SG_IQ4_XS;
                nr0 = N_R0_IQ4_XS;
                smem = 32*sizeof(float);
            } break;
        default:
            {
                GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0);
                GGML_ABORT("not implemented");
            }
    };

    snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
    snprintf(name, 256, "%s_nsg=%d", base, nsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    res.nr0  = nr0;
    res.nr1  = nr1;
    res.nsg  = nsg;
    res.smem = smem;

    return res;
}
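
// The nr0/nr1/nsg/smem fields filled in above (and in the _id variant below) are
// per-quantization dispatch parameters: src0 rows handled per simdgroup, src1 rows per
// threadgroup, simdgroups per threadgroup, and the threadgroup memory the kernel expects.
// They are presumably consumed by the encoder when computing grid and threadgroup sizes
// for the dispatch; this file only selects them.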
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
    snprintf(name, 256, "%s_ne02=%d", base, ne02);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    res.smem = (size_t) ne02*ne20*sizeof(uint16_t);

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
    char base[256];
    char name[256];

    const ggml_type tsrc0 = op->src[0]->type;
    const ggml_type tsrc1 = op->src[1]->type;

    const bool bc_inp = op->src[0]->ne[0] % 32 != 0;

    snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
    snprintf(name, 256, "%s_bci=%d", base, bc_inp);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    res.smem = 8192;

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(int32_t, ne1, op->src[1], ne);

    char base[256];
    char name[256];

    int nsg = 0; // number of simdgroups
    int nr0 = 0; // number of src0 rows per simdgroup
    int nr1 = 1; // number of src1 rows per threadgroup

    size_t smem = 0; // shared memory

    const ggml_type tsrc0 = op->src[0]->type;
    const ggml_type tsrc1 = op->src[1]->type;

    const char * suffix = "";

    // use custom matrix x vector kernel
    switch (tsrc0) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
            {
                nsg = std::min(4, (ne00 + 127) / 128);
                nr0 = 2;
                nr1 = 1;
                smem = 32*sizeof(float)*nr0;
                suffix = ne00 % 4 == 0 ? "_4" : "";
            } break;
        case GGML_TYPE_Q4_0:
            {
                nsg = N_SG_Q4_0;
                nr0 = N_R0_Q4_0;
            } break;
        case GGML_TYPE_Q4_1:
            {
                nsg = N_SG_Q4_1;
                nr0 = N_R0_Q4_1;
            } break;
        case GGML_TYPE_Q5_0:
            {
                nsg = N_SG_Q5_0;
                nr0 = N_R0_Q5_0;
            } break;
        case GGML_TYPE_Q5_1:
            {
                nsg = N_SG_Q5_1;
                nr0 = N_R0_Q5_1;
            } break;
        case GGML_TYPE_Q8_0:
            {
                nsg = N_SG_Q8_0;
                nr0 = N_R0_Q8_0;
                smem = 32*sizeof(float)*N_R0_Q8_0;
            } break;
        case GGML_TYPE_MXFP4:
            {
                nsg = N_SG_MXFP4;
                nr0 = N_R0_MXFP4;
                smem = 32*sizeof(float);
            } break;
        case GGML_TYPE_Q2_K:
            {
                nsg = N_SG_Q2_K;
                nr0 = N_R0_Q2_K;
            } break;
        case GGML_TYPE_Q3_K:
            {
                nsg = N_SG_Q3_K;
                nr0 = N_R0_Q3_K;
            } break;
        case GGML_TYPE_Q4_K:
            {
                nsg = N_SG_Q4_K;
                nr0 = N_R0_Q4_K;
            } break;
        case GGML_TYPE_Q5_K:
            {
                nsg = N_SG_Q5_K;
                nr0 = N_R0_Q5_K;
            } break;
        case GGML_TYPE_Q6_K:
            {
                nsg = N_SG_Q6_K;
                nr0 = N_R0_Q6_K;
            } break;
        case GGML_TYPE_IQ2_XXS:
            {
                nsg = N_SG_IQ2_XXS;
                nr0 = N_R0_IQ2_XXS;
                smem = 256*8+128;
            } break;
        case GGML_TYPE_IQ2_XS:
            {
                nsg = N_SG_IQ2_XS;
                nr0 = N_R0_IQ2_XS;
                smem = 512*8+128;
            } break;
        case GGML_TYPE_IQ3_XXS:
            {
                nsg = N_SG_IQ3_XXS;
                nr0 = N_R0_IQ3_XXS;
                smem = 256*4+128;
            } break;
        case GGML_TYPE_IQ3_S:
            {
                nsg = N_SG_IQ3_S;
                nr0 = N_R0_IQ3_S;
                smem = 512*4;
            } break;
        case GGML_TYPE_IQ2_S:
            {
                nsg = N_SG_IQ2_S;
                nr0 = N_R0_IQ2_S;
            } break;
        case GGML_TYPE_IQ1_S:
            {
                nsg = N_SG_IQ1_S;
                nr0 = N_R0_IQ1_S;
            } break;
        case GGML_TYPE_IQ1_M:
            {
                nsg = N_SG_IQ1_M;
                nr0 = N_R0_IQ1_M;
            } break;
        case GGML_TYPE_IQ4_NL:
            {
                nsg = N_SG_IQ4_NL;
                nr0 = N_R0_IQ4_NL;
                smem = 32*sizeof(float);
            } break;
        case GGML_TYPE_IQ4_XS:
            {
                nsg = N_SG_IQ4_XS;
                nr0 = N_R0_IQ4_XS;
                smem = 32*sizeof(float);
            } break;
        default:
            {
                GGML_LOG_ERROR("Asserting on type %d\n", (int)op->src[2]->type);
                GGML_ABORT("not implemented");
            }
    };

    snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
    snprintf(name, 256, "%s_nsg=%d", base, nsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    res.nr0  = nr0;
    res.nr1  = nr1;
    res.nsg  = nsg;
    res.smem = smem;

    return res;
}
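
// Note: the type switch above mirrors the one in ggml_metal_library_get_pipeline_mul_mv;
// the main structural difference is that the F32/F16/BF16 branch here has no dedicated
// "_short" path for small rows (ne00 < 32).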
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    res.smem = 32*(sizeof(float) + sizeof(int32_t));

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_ARGSORT);

    char base[256];
    char name[256];

    ggml_sort_order order = (ggml_sort_order) op->op_params[0];

    const char * order_str = "undefined";
    switch (order) {
        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
        default: GGML_ABORT("fatal error");
    };

    snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_ARGSORT);

    char base[256];
    char name[256];

    ggml_sort_order order = (ggml_sort_order) op->op_params[0];

    const char * order_str = "undefined";
    switch (order) {
        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
        default: GGML_ABORT("fatal error");
    };

    snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

// note: reuse the argsort kernel for top_k
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_TOP_K);

    char base[256];
    char name[256];

    // note: the top_k kernel is always descending order
    ggml_sort_order order = GGML_SORT_ORDER_DESC;

    const char * order_str = "undefined";
    switch (order) {
        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
        default: GGML_ABORT("fatal error");
    };

    snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_TOP_K);

    char base[256];
    char name[256];

    ggml_sort_order order = GGML_SORT_ORDER_DESC;

    const char * order_str = "undefined";
    switch (order) {
        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
        default: GGML_ABORT("fatal error");
    };

    snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
        ggml_metal_library_t lib,
        const struct ggml_tensor * op,
        bool    has_mask,
        int32_t ncpsg) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);

    GGML_UNUSED(op);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_%s",
            "flash_attn_ext_pad");

    snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
            base,
            has_mask,
            ncpsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
      //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
      //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_PAD + 2);
      //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_PAD + 3);

      //ggml_metal_cv_set_int32(cv, ns10,  FC_FLASH_ATTN_EXT_PAD + 20);
      //ggml_metal_cv_set_int32(cv, ns20,  FC_FLASH_ATTN_EXT_PAD + 21);
      //ggml_metal_cv_set_int32(cv, nsg,   FC_FLASH_ATTN_EXT_PAD + 22);
      //ggml_metal_cv_set_int32(cv, nwg,   FC_FLASH_ATTN_EXT_PAD + 23);
      //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
        ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
        ggml_metal_library_t lib,
        const struct ggml_tensor * op,
        int32_t nqptg,
        int32_t ncpsg) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);

    GGML_UNUSED(op);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_%s",
            "flash_attn_ext_blk");

    snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
            base,
            nqptg,
            ncpsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

      //ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_BLK + 0);
      //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
      //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_BLK + 2);
      //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_BLK + 3);

      //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
      //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
      //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_BLK + 22);
      //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_BLK + 23);
        ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
        ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
        ggml_metal_library_t lib,
        const ggml_tensor * op,
        bool    has_mask,
        bool    has_sinks,
        bool    has_bias,
        bool    has_scap,
        bool    has_kvpad,
        int32_t nsg) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);

    char base[256];
    char name[256];

    const int32_t dk = (int32_t) op->src[1]->ne[0];
    const int32_t dv = (int32_t) op->src[2]->ne[0];

    const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
    const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];

    // do bounds checks for the mask?
    const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);

    snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
            "flash_attn_ext",
            ggml_type_name(op->src[1]->type),
            dk,
            dv);

    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
            base,
            has_mask,
            has_sinks,
            has_bias,
            has_scap,
            has_kvpad,
            bc_mask,
            ns10,
            ns20,
            nsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT + 0);
        ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
        ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT + 2);
        ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT + 3);
        ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);

        ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);

        ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
        ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
        ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT + 22);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}
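
// Function-constant index layout used by the flash-attention getters (as set above):
//   FC_FLASH_ATTN_EXT* + 0..4  -> boolean feature flags (mask, sinks, bias, scap, kvpad)
//   FC_FLASH_ATTN_EXT  + 10    -> bc_mask (bounds checks for the mask)
//   FC_FLASH_ATTN_EXT* + 20..  -> integer parameters (ns10, ns20, nsg, nwg, nqptg, ncpsg)
// Every parameter baked in as a function constant is also encoded into the cache name,
// so each specialization gets its own cached pipeline.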
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
        ggml_metal_library_t lib,
        const ggml_tensor * op,
        bool    has_mask,
        bool    has_sinks,
        bool    has_bias,
        bool    has_scap,
        bool    has_kvpad,
        int32_t nsg,
        int32_t nwg) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);

    char base[256];
    char name[256];

    const int32_t dk = (int32_t) op->src[1]->ne[0];
    const int32_t dv = (int32_t) op->src[2]->ne[0];

    const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
    const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];

    snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
            "flash_attn_ext_vec",
            ggml_type_name(op->src[1]->type),
            dk,
            dv);

    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
            base,
            has_mask,
            has_sinks,
            has_bias,
            has_scap,
            has_kvpad,
            ns10,
            ns20,
            nsg, nwg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_VEC + 0);
        ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
        ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2);
        ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3);
        ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);

        ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
        ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
        ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_VEC + 22);
        ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_VEC + 23);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
        ggml_metal_library_t lib,
        const ggml_tensor * op,
        int32_t dv,
        int32_t nwg) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
    snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int32(cv, dv,  FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
        ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;

    GGML_UNUSED(op);
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
        ggml_metal_library_t lib,
        ggml_op op,
        int32_t n_fuse,
        bool    row) {
    char base[256];
    char name[256];

    const char * op_str = "undefined";
    switch (op) {
        case GGML_OP_ADD: op_str = "add"; break;
        case GGML_OP_SUB: op_str = "sub"; break;
        case GGML_OP_MUL: op_str = "mul"; break;
        case GGML_OP_DIV: op_str = "div"; break;
        default: GGML_ABORT("fatal error");
    };

    if (row) {
        snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
    } else {
        snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
    }

    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}
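
// Example of the names generated above: op = GGML_OP_ADD, n_fuse = 2 yields
// "kernel_add_fuse_2", or "kernel_add_row_c4_fuse_2" for the row variant.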
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_L2_NORM);

    GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_l2_norm_f32");
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    res.smem = 32*sizeof(float);

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_GROUP_NORM);

    GGML_ASSERT(ggml_is_contiguous(op->src[0]));

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_group_norm_f32");
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    res.smem = 32*sizeof(float);

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
    assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);

    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));

    char base[256];
    char name[256];

    const char * suffix = "";
    if (op->ne[0] % 4 == 0) {
        suffix = "_4";
    }

    switch (op->op) {
        case GGML_OP_NORM:
            switch (n_fuse) {
                case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break;
                case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break;
                case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
                default: GGML_ABORT("fatal error");
            } break;
        case GGML_OP_RMS_NORM:
            switch (n_fuse) {
                case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break;
                case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break;
                case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
                default: GGML_ABORT("fatal error");
            } break;
        default: GGML_ABORT("fatal error");
    }

    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    res.smem = 32*sizeof(float);

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_ROPE);

    char base[256];
    char name[256];

    const int mode = ((const int32_t *) op->op_params)[2];

    const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_neox) {
        snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
    } else if ((is_mrope || is_imrope) && !is_vision) {
        GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
        snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
    } else if (is_vision) {
        GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
        snprintf(base, 256, "kernel_rope_vision_%s", ggml_type_name(op->src[0]->type));
    } else {
        snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
    }

    snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}
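
// The ROPE mode bits select the kernel family (norm/neox/multi/vision), while the
// interleaved-mrope case is passed separately as the FC_ROPE + 0 function constant and
// reflected in the cache key via the "_imrope=" suffix.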
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_IM2COL);

    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_CONV_TRANSPOSE_1D);

    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->type == GGML_TYPE_F32);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);

    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->type == GGML_TYPE_F32);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_CONV_2D);

    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->type == GGML_TYPE_F32);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_UPSCALE);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_PAD);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (res.pipeline) {
        return res;
    }

    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_PAD_REFLECT_1D);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_ARANGE);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_TIMESTEP_EMBEDDING);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_OPT_STEP_ADAMW);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_OPT_STEP_SGD);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(op->type == GGML_TYPE_I64);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_COUNT_EQUAL);

    GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);

    GGML_ASSERT(op->src[0]->type == op->src[1]->type);
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
    GGML_ASSERT(op->type == GGML_TYPE_I64);

    // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
    GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));

    char base[256];
    char name[256];

    int nsg = 1;
    while (32*nsg < ne00 && nsg < 32) {
        nsg *= 2;
    }

    snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s_nsg=%d", base, nsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    res.smem = 32 * sizeof(int32_t);
    res.nsg  = nsg;

    return res;
}
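
// Note: the loop above picks the smallest power-of-two simdgroup count such that 32*nsg
// covers ne00, capped at 32 simdgroups (e.g. ne00 = 100 -> nsg = 4). The 32 ints of
// threadgroup memory are presumably used for the per-simdgroup partial counts.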