vulkan-shaders-gen.cpp 62 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194
  1. #include <iostream>
  2. #include <fstream>
  3. #include <sstream>
  4. #include <string>
  5. #include <stdexcept>
  6. #include <array>
  7. #include <vector>
  8. #include <map>
  9. #include <thread>
  10. #include <mutex>
  11. #include <future>
  12. #include <queue>
  13. #include <condition_variable>
  14. #include <cstdio>
  15. #include <cstring>
  16. #include <cstdlib>
  17. #include <cassert>
  18. #include <algorithm>
  19. #include <sys/stat.h>
  20. #include <sys/types.h>
  21. #include <filesystem>
  22. #ifdef _WIN32
  23. #define NOMINMAX
  24. #include <windows.h>
  25. #include <direct.h> // For _mkdir on Windows
  26. #else
  27. #include <unistd.h>
  28. #include <sys/wait.h>
  29. #include <fcntl.h>
  30. #endif
  31. #define ASYNCIO_CONCURRENCY 64
  32. std::mutex lock;
  33. std::vector<std::pair<std::string, std::string>> shader_fnames;
  34. std::locale c_locale("C");
  35. std::string GLSLC = "glslc";
  36. std::string input_filepath = "";
  37. std::string output_dir = "/tmp";
  38. std::string target_hpp = "";
  39. std::string target_cpp = "";
  40. const std::vector<std::string> type_names = {
  41. "f32",
  42. "f16",
  43. "q4_0",
  44. "q4_1",
  45. "q5_0",
  46. "q5_1",
  47. "q8_0",
  48. "q2_k",
  49. "q3_k",
  50. "q4_k",
  51. "q5_k",
  52. "q6_k",
  53. "iq1_s",
  54. "iq1_m",
  55. "iq2_xxs",
  56. "iq2_xs",
  57. "iq2_s",
  58. "iq3_xxs",
  59. "iq3_s",
  60. "iq4_xs",
  61. "iq4_nl",
  62. "mxfp4",
  63. "bf16",
  64. };
  65. enum MatMulIdType {
  66. NONE,
  67. DEFAULT,
  68. SUBGROUP,
  69. };
  70. namespace {
  71. void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
  72. #ifdef _WIN32
  73. HANDLE stdout_read, stdout_write;
  74. HANDLE stderr_read, stderr_write;
  75. SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };
  76. if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
  77. !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
  78. throw std::runtime_error("Failed to create stdout pipe");
  79. }
  80. if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
  81. !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
  82. throw std::runtime_error("Failed to create stderr pipe");
  83. }
  84. PROCESS_INFORMATION pi;
  85. STARTUPINFOA si = {};
  86. si.cb = sizeof(STARTUPINFOA);
  87. si.dwFlags = STARTF_USESTDHANDLES;
  88. si.hStdOutput = stdout_write;
  89. si.hStdError = stderr_write;
  90. std::string cmd;
  91. for (const auto& part : command) {
  92. cmd += part + " ";
  93. }
  94. if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
  95. throw std::runtime_error("Failed to create process");
  96. }
  97. CloseHandle(stdout_write);
  98. CloseHandle(stderr_write);
  99. std::array<char, 128> buffer;
  100. DWORD bytes_read;
  101. while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
  102. stdout_str.append(buffer.data(), bytes_read);
  103. }
  104. while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
  105. stderr_str.append(buffer.data(), bytes_read);
  106. }
  107. CloseHandle(stdout_read);
  108. CloseHandle(stderr_read);
  109. WaitForSingleObject(pi.hProcess, INFINITE);
  110. CloseHandle(pi.hProcess);
  111. CloseHandle(pi.hThread);
  112. #else
  113. int stdout_pipe[2];
  114. int stderr_pipe[2];
  115. if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
  116. throw std::runtime_error("Failed to create pipes");
  117. }
  118. pid_t pid = fork();
  119. if (pid < 0) {
  120. throw std::runtime_error("Failed to fork process");
  121. }
  122. std::vector<char*> argv;
  123. for (std::string& part : command) {
  124. argv.push_back(part.data());
  125. }
  126. argv.push_back(nullptr);
  127. if (pid == 0) {
  128. close(stdout_pipe[0]);
  129. close(stderr_pipe[0]);
  130. dup2(stdout_pipe[1], STDOUT_FILENO);
  131. dup2(stderr_pipe[1], STDERR_FILENO);
  132. close(stdout_pipe[1]);
  133. close(stderr_pipe[1]);
  134. execvp(argv[0], argv.data());
  135. _exit(EXIT_FAILURE);
  136. } else {
  137. close(stdout_pipe[1]);
  138. close(stderr_pipe[1]);
  139. std::array<char, 128> buffer;
  140. ssize_t bytes_read;
  141. while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
  142. stdout_str.append(buffer.data(), bytes_read);
  143. }
  144. while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
  145. stderr_str.append(buffer.data(), bytes_read);
  146. }
  147. close(stdout_pipe[0]);
  148. close(stderr_pipe[0]);
  149. waitpid(pid, nullptr, 0);
  150. }
  151. #endif
  152. }
  153. bool directory_exists(const std::string& path) {
  154. struct stat info;
  155. if (stat(path.c_str(), &info) != 0) {
  156. return false; // Path doesn't exist or can't be accessed
  157. }
  158. return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory
  159. }
  160. bool create_directory(const std::string& path) {
  161. #ifdef _WIN32
  162. return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists
  163. #else
  164. return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions
  165. #endif
  166. }
  167. std::string to_uppercase(const std::string& input) {
  168. std::string result = input;
  169. for (char& c : result) {
  170. c = std::toupper(c);
  171. }
  172. return result;
  173. }
  174. bool string_starts_with(const std::string& str, const std::string& prefix) {
  175. if (prefix.size() > str.size()) {
  176. return false;
  177. }
  178. return std::equal(prefix.begin(), prefix.end(), str.begin());
  179. }
  180. bool string_ends_with(const std::string& str, const std::string& suffix) {
  181. if (suffix.size() > str.size()) {
  182. return false;
  183. }
  184. return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
  185. }
  186. bool is_quantized_type(const std::string& type_name) {
  187. return type_name != "f32" && type_name != "f16" && type_name != "bf16";
  188. }
  189. bool is_legacy_quant(const std::string& type_name) {
  190. return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0";
  191. }
  192. bool is_k_quant(const std::string& type_name) {
  193. return string_ends_with(type_name, "_k");
  194. }
  195. bool is_iq_quant(const std::string& type_name) {
  196. return string_starts_with(type_name, "iq");
  197. }
  198. static const char path_separator = '/';
  199. std::string join_paths(const std::string& path1, const std::string& path2) {
  200. return path1 + path_separator + path2;
  201. }
  202. std::string basename(const std::string &path) {
  203. return path.substr(path.find_last_of("/\\") + 1);
  204. }
  205. std::stringstream make_generic_stringstream() {
  206. std::stringstream ss;
  207. ss.imbue(c_locale);
  208. return ss;
  209. }
  210. std::string read_binary_file(const std::string& path, bool may_not_exist = false) {
  211. FILE* f = fopen(path.c_str(), "rb");
  212. if (!f) {
  213. if (!may_not_exist) {
  214. std::cerr << "Error opening file: " << path << " (" << strerror(errno) << ")\n";
  215. }
  216. return {};
  217. }
  218. fseek(f, 0, SEEK_END);
  219. size_t size = ftell(f);
  220. fseek(f, 0, SEEK_SET);
  221. std::string data(size, '\0');
  222. size_t read_size = fread(data.data(), 1, size, f);
  223. fclose(f);
  224. if (read_size != size) {
  225. std::cerr << "Error reading file: " << path << " (" << strerror(errno) << ")\n";
  226. return {};
  227. }
  228. return data;
  229. }
  230. void write_binary_file(const std::string& path, const std::string& content) {
  231. FILE* f = fopen(path.c_str(), "wb");
  232. if (!f) {
  233. std::cerr << "Error opening file for writing: " << path << " (" << strerror(errno) << ")\n";
  234. return;
  235. }
  236. size_t write_size = fwrite(content.data(), 1, content.size(), f);
  237. fclose(f);
  238. if (write_size != content.size()) {
  239. std::cerr << "Error writing file: " << path << " (" << strerror(errno) << ")\n";
  240. return;
  241. }
  242. }
  243. void write_file_if_changed(const std::string& path, const std::string& content) {
  244. std::string existing = read_binary_file(path, true);
  245. if (existing != content) {
  246. write_binary_file(path, content);
  247. }
  248. }
  249. // variables to track number of compiles in progress
  250. static uint32_t compile_count = 0;
  251. static std::mutex compile_count_mutex;
  252. static std::condition_variable compile_count_cond;
  253. static bool generate_dep_file = true;
  254. void decrement_compile_count(uint32_t * count) {
  255. if (count) {
  256. std::lock_guard<std::mutex> guard(compile_count_mutex);
  257. assert(compile_count > 0);
  258. compile_count--;
  259. compile_count_cond.notify_all();
  260. }
  261. }
  262. using compile_count_guard = std::unique_ptr<uint32_t, decltype(&decrement_compile_count)>;
  263. compile_count_guard acquire_compile_slot() {
  264. // wait until fewer than N compiles are in progress.
  265. // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
  266. uint32_t N = std::max(1u, std::min(16u, std::thread::hardware_concurrency()));
  267. std::unique_lock<std::mutex> guard(compile_count_mutex);
  268. compile_count_cond.wait(guard, [N] { return compile_count < N; });
  269. compile_count++;
  270. return compile_count_guard(&compile_count, &decrement_compile_count);
  271. }
  272. void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map<std::string, std::string> defines, bool coopmat, bool dep_file, compile_count_guard slot) {
  273. std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
  274. #ifdef _WIN32
  275. std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
  276. #else
  277. std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path};
  278. #endif
  279. // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
  280. // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
  281. // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
  282. if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
  283. cmd.push_back("-O");
  284. }
  285. if (dep_file) {
  286. cmd.push_back("-MD");
  287. cmd.push_back("-MF");
  288. #ifdef _WIN32
  289. cmd.push_back("\"" + target_cpp + ".d\"");
  290. #else
  291. cmd.push_back(target_cpp + ".d");
  292. #endif
  293. }
  294. #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
  295. cmd.push_back("-g");
  296. #endif
  297. for (const auto& define : defines) {
  298. cmd.push_back("-D" + define.first + "=" + define.second);
  299. }
  300. std::string command;
  301. for (const auto& part : cmd) {
  302. command += part + " ";
  303. }
  304. std::string stdout_str, stderr_str;
  305. try {
  306. // std::cout << "Executing command: ";
  307. // for (const auto& part : cmd) {
  308. // std::cout << part << " ";
  309. // }
  310. // std::cout << std::endl;
  311. execute_command(cmd, stdout_str, stderr_str);
  312. if (!stderr_str.empty()) {
  313. std::cerr << "cannot compile " << name << "\n\n";
  314. for (const auto& part : cmd) {
  315. std::cerr << part << " ";
  316. }
  317. std::cerr << "\n\n" << stderr_str << std::endl;
  318. return;
  319. }
  320. if (dep_file) {
  321. // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt
  322. std::string dep = read_binary_file(target_cpp + ".d", true);
  323. if (!dep.empty()) {
  324. size_t pos = dep.find(out_path);
  325. if (pos != std::string::npos) {
  326. dep.replace(pos, out_path.length(), target_cpp);
  327. }
  328. write_binary_file(target_cpp + ".d", dep);
  329. }
  330. }
  331. std::lock_guard<std::mutex> guard(lock);
  332. shader_fnames.push_back(std::make_pair(name, out_path));
  333. } catch (const std::exception& e) {
  334. std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
  335. }
  336. }
  337. std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
  338. std::map<std::string, std::string> result = a;
  339. result.insert(b.begin(), b.end());
  340. return result;
  341. }
  342. static std::vector<std::future<void>> compiles;
  343. void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
  344. name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
  345. std::string out_path = join_paths(output_dir, name + ".spv");
  346. if (input_filepath == "") {
  347. // No input source to compile, only generate header for all shaders
  348. shader_fnames.push_back(std::pair(name, out_path));
  349. return;
  350. } else if (basename(input_filepath) != source) {
  351. // Only compile shader variants matching the input filename
  352. return;
  353. }
  354. compile_count_guard slot = acquire_compile_slot();
  355. compiles.push_back(std::async(
  356. string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot)));
  357. // Don't write the same dep file from multiple processes
  358. generate_dep_file = false;
  359. }
  360. void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
  361. std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
  362. std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
  363. std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
  364. std::map<std::string, std::string> base_dict;
  365. std::string shader_name = "matmul";
  366. if (matmul_id_type == MatMulIdType::DEFAULT) {
  367. base_dict["MUL_MAT_ID"] = "1";
  368. shader_name = "matmul_id";
  369. } else if (matmul_id_type == MatMulIdType::SUBGROUP) {
  370. base_dict["MUL_MAT_ID"] = "1";
  371. base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
  372. shader_name = "matmul_id_subgroup";
  373. }
  374. if (fp16) {
  375. base_dict["FLOAT16"] = "1";
  376. }
  377. base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
  378. base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
  379. if (f16acc) {
  380. base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
  381. }
  382. if (coopmat) {
  383. base_dict["COOPMAT"] = "1";
  384. }
  385. const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
  386. auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string {
  387. switch (vec) {
  388. case 1:
  389. if (t == "bf16") {
  390. // scalar path promotes to float
  391. if (!coopmat && !coopmat2) {
  392. return "float";
  393. }
  394. return "bfloat16_t";
  395. }
  396. if (coopmat2 || fp16) {
  397. return "float16_t";
  398. }
  399. return "float";
  400. case 2:
  401. if (t == "bf16") {
  402. // scalar path promotes to float
  403. if (!coopmat && !coopmat2) {
  404. return "vec2";
  405. }
  406. return "bf16vec2";
  407. }
  408. if (coopmat2 || fp16) {
  409. return "f16vec2";
  410. }
  411. return "vec2";
  412. case 4:
  413. if (t == "bf16") {
  414. // scalar path promotes to float
  415. if (!coopmat && !coopmat2) {
  416. return "vec4";
  417. }
  418. return "bf16vec4";
  419. }
  420. if (coopmat2 || fp16) {
  421. return "f16vec4";
  422. }
  423. return "vec4";
  424. case 8:
  425. if (t == "bf16") {
  426. // scalar path promotes to float
  427. if (!coopmat && !coopmat2) {
  428. return "mat2x4";
  429. }
  430. throw std::runtime_error("bf16 vec8 not supported");
  431. }
  432. if (coopmat2 || fp16) {
  433. return "f16mat2x4";
  434. }
  435. return "mat2x4";
  436. default:
  437. throw std::runtime_error("invalid vector size");
  438. }
  439. };
  440. const std::map<std::string, std::string> float_type_dict_f16 = {
  441. {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")},
  442. {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")},
  443. {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")},
  444. {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")},
  445. };
  446. // Shaders with f16 B_TYPE
  447. string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
  448. string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  449. string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
  450. string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  451. // bf16
  452. {
  453. // For aligned matmul loads
  454. std::string load_vec_a = coopmat2 ? "1" : "4";
  455. // scalar path promotes to float
  456. std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
  457. const std::map<std::string, std::string> float_type_dict_bf16 = {
  458. {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")},
  459. {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")},
  460. {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")},
  461. };
  462. // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
  463. #if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
  464. if (!(coopmat || coopmat2))
  465. #endif
  466. {
  467. string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
  468. string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  469. }
  470. }
  471. for (const auto& tname : type_names) {
  472. std::string load_vec_quant = "2";
  473. if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
  474. load_vec_quant = "8";
  475. else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
  476. load_vec_quant = "4";
  477. if (tname == "bf16") {
  478. continue;
  479. }
  480. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  481. // For unaligned, load one at a time for f32/f16, or two at a time for quants
  482. std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
  483. // For aligned matmul loads
  484. std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
  485. const std::map<std::string, std::string> float_type_dict = {
  486. {"FLOAT_TYPE", FLOAT_TYPE(1, tname)},
  487. {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)},
  488. {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)},
  489. {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)},
  490. };
  491. // don't generate f32 variants for coopmat2
  492. if (!coopmat2) {
  493. string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
  494. string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  495. }
  496. if (tname != "f16" && tname != "f32") {
  497. string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
  498. string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  499. }
  500. #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  501. // Integer dot mmq performs better with f32 accumulators
  502. if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
  503. string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
  504. }
  505. #endif
  506. }
  507. }
  508. void process_shaders() {
  509. std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
  510. // matmul
  511. for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
  512. // No coopmats
  513. // fp32
  514. matmul_shaders(false, matmul_id_type, false, false, false);
  515. // fp16, fp32acc and fp16acc
  516. matmul_shaders(true, matmul_id_type, false, false, false);
  517. matmul_shaders(true, matmul_id_type, false, false, true);
  518. if (matmul_id_type != MatMulIdType::DEFAULT) {
  519. #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  520. // Coopmat, fp32acc and fp16acc
  521. matmul_shaders(true, matmul_id_type, true, false, false);
  522. matmul_shaders(true, matmul_id_type, true, false, true);
  523. #endif
  524. #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
  525. // Coopmat2, fp32acc and fp16acc
  526. matmul_shaders(true, matmul_id_type, false, true, false);
  527. matmul_shaders(true, matmul_id_type, false, true, true);
  528. #endif
  529. }
  530. }
  531. // flash attention
  532. for (const auto& f16acc : {false, true}) {
  533. std::map<std::string, std::string> fa_base_dict = base_dict;
  534. fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
  535. fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
  536. if (f16acc) {
  537. fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
  538. }
  539. for (const auto& tname : type_names) {
  540. if (tname == "bf16") continue;
  541. #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
  542. if (tname == "f16") {
  543. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
  544. merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
  545. } else {
  546. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  547. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
  548. merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
  549. }
  550. #endif
  551. #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  552. if (tname == "f16") {
  553. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
  554. merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
  555. } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
  556. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  557. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
  558. merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
  559. }
  560. #endif
  561. if (tname == "f16") {
  562. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
  563. merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
  564. } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
  565. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  566. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
  567. merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
  568. }
  569. }
  570. }
  571. for (const auto& tname : type_names) {
  572. // mul mat vec
  573. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  574. std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
  575. string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
  576. string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
  577. string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
  578. string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
  579. string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
  580. string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
  581. string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
  582. string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
  583. string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
  584. // mul mat vec with integer dot product
  585. #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  586. if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) {
  587. string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
  588. string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
  589. string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
  590. string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
  591. string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
  592. string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
  593. }
  594. #endif
  595. // Dequant shaders
  596. if (tname != "f16" && tname != "bf16") {
  597. string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
  598. }
  599. shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
  600. if (tname == "f16") {
  601. string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
  602. } else {
  603. string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
  604. }
  605. string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
  606. }
  607. string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}});
  608. string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
  609. string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
  610. string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
  611. // Norms
  612. string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  613. string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  614. string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  615. string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  616. string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
  617. string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}}));
  618. string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  619. string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  620. string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  621. string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
  622. string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  623. string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  624. string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
  625. string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  626. string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
  627. string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
  628. string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
  629. string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  630. string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  631. string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
  632. string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
  633. string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
  634. string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
  635. string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}});
  636. for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
  637. string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  638. string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
  639. string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  640. }
  641. for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
  642. string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  643. string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
  644. string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  645. string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
  646. }
  647. auto get_type_str = [](bool f16) {
  648. return f16 ? "float16_t" : "float";
  649. };
  650. auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
  651. std::string s;
  652. s += std::string(src0_f16 ? "_f16" : "_f32");
  653. s += std::string(src1_f16 ? "_f16" : "_f32");
  654. s += std::string(dst_f16 ? "_f16" : "_f32");
  655. return s;
  656. };
  657. for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
  658. for (auto src0_f16 : {false, true}) {
  659. for (auto src1_f16 : {false, true}) {
  660. for (auto dst_f16 : {false, true}) {
  661. for (auto rte : {false, true}) {
  662. auto source = op == "add_rms" ? std::string("add") : op;
  663. auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
  664. auto add_rms = op == "add_rms" ? "1" : "0";
  665. string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
  666. }
  667. }
  668. }
  669. }
  670. }
  671. string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  672. string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  673. string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
  674. string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
  675. string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
  676. string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
  677. string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}});
  678. string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}});
  679. string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  680. string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  681. string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  682. string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  683. string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  684. string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  685. string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  686. string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  687. string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  688. string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  689. string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  690. string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  691. string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  692. string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
  693. string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  694. for (auto rte : {false, true}) {
  695. std::string suffix = rte ? "_rte" : "";
  696. string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  697. string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}});
  698. string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  699. string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  700. }
  701. string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  702. string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  703. string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  704. string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  705. string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  706. string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  707. string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  708. string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  709. string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  710. string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  711. string_to_spv("neg_f16", "neg.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  712. string_to_spv("neg_f32", "neg.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  713. string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  714. string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  715. string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  716. string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  717. string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  718. string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  719. string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  720. string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  721. string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  722. string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  723. string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  724. string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  725. string_to_spv("diag_f16", "diag.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  726. string_to_spv("diag_f32", "diag.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  727. string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  728. string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  729. string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
  730. string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
  731. string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  732. string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  733. string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  734. string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  735. string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  736. string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  737. string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  738. string_to_spv("ceil_f16", "ceil.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  739. string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  740. string_to_spv("floor_f16", "floor.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  741. string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  742. string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  743. string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  744. for (auto rte : {false, true}) {
  745. std::string suffix = rte ? "_rte" : "";
  746. string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  747. string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  748. string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  749. string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  750. string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  751. string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  752. string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  753. string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  754. string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  755. string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  756. string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  757. string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  758. }
  759. string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  760. string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  761. string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  762. string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  763. string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
  764. string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  765. string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  766. string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  767. string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  768. string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
  769. string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
  770. string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
  771. string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
  772. string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
  773. string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  774. string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
  775. string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  776. string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
  777. string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
  778. string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  779. string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
  780. string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  781. string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
  782. string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
  783. string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  784. string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
  785. string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
  786. string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  787. string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
  788. string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
  789. string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
  790. string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
  791. string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
  792. string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  793. string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
  794. string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  795. for (std::string dim_str : {"", "_3d"}) {
  796. for (bool bda : {false, true}) {
  797. std::string bda_str = bda ? "_bda" : "";
  798. std::string bda_def = bda ? "1" : "0";
  799. string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
  800. string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
  801. string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
  802. }
  803. }
  804. string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  805. string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  806. string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  807. string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  808. string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  809. string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  810. string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  811. string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  812. for (auto transpose : {false, true}) {
  813. for (auto unroll : {false, true}) {
  814. for (auto a_f16 : {false, true}) {
  815. std::map<std::string, std::string> defines = {
  816. {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
  817. {"USE_COLLECTIVES", "1"}, {"UNROLL", unroll ? "[[unroll]]" : ""},
  818. };
  819. if (transpose) defines["TRANSPOSE"] = "1";
  820. std::string name = std::string(transpose ? "conv_transpose_2d": "conv2d")
  821. + (a_f16 ? "_f16" : "") + "_f32";
  822. string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines);
  823. #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
  824. if (unroll) {
  825. defines["COOPMAT2"] = "1";
  826. string_to_spv(name, "conv2d_mm.comp", defines, true, false, true);
  827. }
  828. #endif
  829. }
  830. }
  831. }
  832. string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
  833. string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
  834. string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
  835. string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
  836. string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  837. string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  838. string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
  839. string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
  840. string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
  841. string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
  842. string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
  843. string_to_spv("topk_moe_f32", "topk_moe.comp", {});
  844. for (auto &c : compiles) {
  845. c.wait();
  846. }
  847. }
  848. void write_output_files() {
  849. std::stringstream hdr = make_generic_stringstream();
  850. std::stringstream src = make_generic_stringstream();
  851. hdr << "#include <cstdint>\n\n";
  852. src << "#include \"" << basename(target_hpp) << "\"\n\n";
  853. std::sort(shader_fnames.begin(), shader_fnames.end());
  854. for (const auto& pair : shader_fnames) {
  855. const std::string& name = pair.first;
  856. #ifdef _WIN32
  857. std::string path = pair.second;
  858. std::replace(path.begin(), path.end(), '/', '\\' );
  859. #else
  860. const std::string& path = pair.second;
  861. #endif
  862. hdr << "extern const uint64_t " << name << "_len;\n";
  863. hdr << "extern const unsigned char " << name << "_data[];\n\n";
  864. if (input_filepath != "") {
  865. std::string data = read_binary_file(path);
  866. if (data.empty()) {
  867. continue;
  868. }
  869. src << "const uint64_t " << name << "_len = " << data.size() << ";\n";
  870. src << "const unsigned char " << name << "_data[" << data.size() << "] = {\n" << std::hex;
  871. auto bytes = reinterpret_cast<const uint8_t*>(data.data());
  872. for (size_t i = 0; i < data.size(); ++i) {
  873. src << "0x" << static_cast<int>(bytes[i]) << ",";
  874. if ((i + 1) % 12 == 0) src << "\n";
  875. }
  876. src << std::dec << "\n};\n\n";
  877. }
  878. }
  879. std::string suffixes[2] = {"_f32", "_f16"};
  880. for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
  881. hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
  882. hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
  883. std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp";
  884. if (basename(input_filepath) != op_file) {
  885. continue;
  886. }
  887. std::stringstream data = make_generic_stringstream();
  888. std::stringstream len = make_generic_stringstream();
  889. data << "const void * " << op << "_data[2][2][2][2] = ";
  890. len << "const uint64_t " << op << "_len[2][2][2][2] = ";
  891. for (uint32_t t0 = 0; t0 < 2; ++t0) {
  892. if (t0 == 0) {
  893. data << "{";
  894. len << "{";
  895. }
  896. for (uint32_t t1 = 0; t1 < 2; ++t1) {
  897. if (t1 == 0) {
  898. data << "{";
  899. len << "{";
  900. }
  901. for (uint32_t t2 = 0; t2 < 2; ++t2) {
  902. if (t2 == 0) {
  903. data << "{";
  904. len << "{";
  905. }
  906. for (uint32_t rte = 0; rte < 2; ++rte) {
  907. if (rte == 0) {
  908. data << "{";
  909. len << "{";
  910. }
  911. data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
  912. len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
  913. data << "_data,";
  914. len << "_len,";
  915. if (rte == 1) {
  916. data << "}, ";
  917. len << "}, ";
  918. }
  919. }
  920. if (t2 == 1) {
  921. data << "}, ";
  922. len << "}, ";
  923. }
  924. }
  925. if (t1 == 1) {
  926. data << "}, ";
  927. len << "}, ";
  928. }
  929. }
  930. if (t0 == 1) {
  931. data << "};\n";
  932. len << "};\n";
  933. }
  934. }
  935. src << data.str();
  936. src << len.str();
  937. }
  938. std::vector<std::string> btypes = {"f16", "f32"};
  939. #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  940. btypes.push_back("q8_1");
  941. #endif
  942. for (const std::string& btype : btypes) {
  943. for (const auto& tname : type_names) {
  944. if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) {
  945. continue;
  946. }
  947. hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
  948. hdr << "extern const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3];\n";
  949. if (basename(input_filepath) == "mul_mat_vec.comp") {
  950. src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
  951. src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
  952. }
  953. if (btype == "f16") {
  954. continue;
  955. }
  956. hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n";
  957. hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n";
  958. if (basename(input_filepath) == "mul_mat_vec.comp") {
  959. src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
  960. src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
  961. }
  962. }
  963. }
  964. if (input_filepath == "") {
  965. write_file_if_changed(target_hpp, hdr.str());
  966. }
  967. if (target_cpp != "") {
  968. write_binary_file(target_cpp, src.str());
  969. }
  970. }
  971. } // namespace
  972. int main(int argc, char** argv) {
  973. std::map<std::string, std::string> args;
  974. for (int i = 1; i < argc; ++i) {
  975. std::string arg = argv[i];
  976. if (arg.rfind("--", 0) == 0) {
  977. if (i + 1 < argc && argv[i + 1][0] != '-') {
  978. args[arg] = argv[i + 1];
  979. ++i;
  980. } else {
  981. args[arg] = "";
  982. }
  983. }
  984. }
  985. if (args.find("--glslc") != args.end()) {
  986. GLSLC = args["--glslc"]; // Path to glslc
  987. }
  988. if (args.find("--source") != args.end()) {
  989. input_filepath = args["--source"]; // The shader source file to compile
  990. }
  991. if (args.find("--output-dir") != args.end()) {
  992. output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
  993. }
  994. if (args.find("--target-hpp") != args.end()) {
  995. target_hpp = args["--target-hpp"]; // Path to generated header file
  996. }
  997. if (args.find("--target-cpp") != args.end()) {
  998. target_cpp = args["--target-cpp"]; // Path to generated cpp file
  999. }
  1000. if (!directory_exists(output_dir)) {
  1001. if (!create_directory(output_dir)) {
  1002. std::cerr << "Error creating output directory: " << output_dir << "\n";
  1003. return EXIT_FAILURE;
  1004. }
  1005. }
  1006. process_shaders();
  1007. write_output_files();
  1008. return EXIT_SUCCESS;
  1009. }