vulkan-shaders-gen.cpp 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <stdexcept>
#include <system_error>
#include <array>
#include <vector>
#include <map>
#include <memory>
#include <locale>
#include <thread>
#include <mutex>
#include <future>
#include <queue>
#include <condition_variable>
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cassert>
#include <algorithm>
#include <sys/stat.h>
#include <sys/types.h>
#include <filesystem>
#ifdef _WIN32
#define NOMINMAX
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
#else
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
#endif
  31. #define ASYNCIO_CONCURRENCY 64
  32. std::mutex lock;
  33. std::vector<std::pair<std::string, std::string>> shader_fnames;
  34. std::locale c_locale("C");
  35. std::string GLSLC = "glslc";
  36. std::string input_filepath = "";
  37. std::string output_dir = "/tmp";
  38. std::string target_hpp = "";
  39. std::string target_cpp = "";
  40. const std::vector<std::string> type_names = {
  41. "f32",
  42. "f16",
  43. "q4_0",
  44. "q4_1",
  45. "q5_0",
  46. "q5_1",
  47. "q8_0",
  48. "q2_k",
  49. "q3_k",
  50. "q4_k",
  51. "q5_k",
  52. "q6_k",
  53. "iq1_s",
  54. "iq1_m",
  55. "iq2_xxs",
  56. "iq2_xs",
  57. "iq2_s",
  58. "iq3_xxs",
  59. "iq3_s",
  60. "iq4_xs",
  61. "iq4_nl",
  62. "mxfp4",
  63. "bf16",
  64. };
  65. enum MatMulIdType {
  66. NONE,
  67. DEFAULT,
  68. SUBGROUP,
  69. };
  70. namespace {
  71. void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
  72. #ifdef _WIN32
  73. HANDLE stdout_read, stdout_write;
  74. HANDLE stderr_read, stderr_write;
  75. SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };
  76. if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
  77. !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
  78. throw std::runtime_error("Failed to create stdout pipe");
  79. }
  80. if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
  81. !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
  82. throw std::runtime_error("Failed to create stderr pipe");
  83. }
  84. PROCESS_INFORMATION pi;
  85. STARTUPINFOA si = {};
  86. si.cb = sizeof(STARTUPINFOA);
  87. si.dwFlags = STARTF_USESTDHANDLES;
  88. si.hStdOutput = stdout_write;
  89. si.hStdError = stderr_write;
  90. std::string cmd;
  91. for (const auto& part : command) {
  92. cmd += part + " ";
  93. }
  94. if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
  95. throw std::runtime_error("Failed to create process");
  96. }
  97. CloseHandle(stdout_write);
  98. CloseHandle(stderr_write);
  99. std::array<char, 128> buffer;
  100. DWORD bytes_read;
  101. while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
  102. stdout_str.append(buffer.data(), bytes_read);
  103. }
  104. while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
  105. stderr_str.append(buffer.data(), bytes_read);
  106. }
  107. CloseHandle(stdout_read);
  108. CloseHandle(stderr_read);
  109. WaitForSingleObject(pi.hProcess, INFINITE);
  110. CloseHandle(pi.hProcess);
  111. CloseHandle(pi.hThread);
  112. #else
  113. int stdout_pipe[2];
  114. int stderr_pipe[2];
  115. if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
  116. throw std::runtime_error("Failed to create pipes");
  117. }
  118. pid_t pid = fork();
  119. if (pid < 0) {
  120. throw std::runtime_error("Failed to fork process");
  121. }
  122. std::vector<char*> argv;
  123. for (std::string& part : command) {
  124. argv.push_back(part.data());
  125. }
  126. argv.push_back(nullptr);
  127. if (pid == 0) {
  128. close(stdout_pipe[0]);
  129. close(stderr_pipe[0]);
  130. dup2(stdout_pipe[1], STDOUT_FILENO);
  131. dup2(stderr_pipe[1], STDERR_FILENO);
  132. close(stdout_pipe[1]);
  133. close(stderr_pipe[1]);
  134. execvp(argv[0], argv.data());
  135. _exit(EXIT_FAILURE);
  136. } else {
  137. close(stdout_pipe[1]);
  138. close(stderr_pipe[1]);
  139. std::array<char, 128> buffer;
  140. ssize_t bytes_read;
  141. while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
  142. stdout_str.append(buffer.data(), bytes_read);
  143. }
  144. while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
  145. stderr_str.append(buffer.data(), bytes_read);
  146. }
  147. close(stdout_pipe[0]);
  148. close(stderr_pipe[0]);
  149. waitpid(pid, nullptr, 0);
  150. }
  151. #endif
  152. }
  153. bool directory_exists(const std::string& path) {
  154. struct stat info;
  155. if (stat(path.c_str(), &info) != 0) {
  156. return false; // Path doesn't exist or can't be accessed
  157. }
  158. return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory
  159. }
  160. bool create_directory(const std::string& path) {
  161. #ifdef _WIN32
  162. return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists
  163. #else
  164. return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions
  165. #endif
  166. }
  167. std::string to_uppercase(const std::string& input) {
  168. std::string result = input;
  169. for (char& c : result) {
  170. c = std::toupper(c);
  171. }
  172. return result;
  173. }
  174. bool string_starts_with(const std::string& str, const std::string& prefix) {
  175. if (prefix.size() > str.size()) {
  176. return false;
  177. }
  178. return std::equal(prefix.begin(), prefix.end(), str.begin());
  179. }
  180. bool string_ends_with(const std::string& str, const std::string& suffix) {
  181. if (suffix.size() > str.size()) {
  182. return false;
  183. }
  184. return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
  185. }
  186. bool is_quantized_type(const std::string& type_name) {
  187. return type_name != "f32" && type_name != "f16" && type_name != "bf16";
  188. }
  189. bool is_legacy_quant(const std::string& type_name) {
  190. return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0";
  191. }
  192. bool is_k_quant(const std::string& type_name) {
  193. return string_ends_with(type_name, "_k");
  194. }
  195. bool is_iq_quant(const std::string& type_name) {
  196. return string_starts_with(type_name, "iq");
  197. }
  198. static const char path_separator = '/';
  199. std::string join_paths(const std::string& path1, const std::string& path2) {
  200. return path1 + path_separator + path2;
  201. }
  202. std::string basename(const std::string &path) {
  203. return path.substr(path.find_last_of("/\\") + 1);
  204. }
  205. std::stringstream make_generic_stringstream() {
  206. std::stringstream ss;
  207. ss.imbue(c_locale);
  208. return ss;
  209. }
  210. std::string read_binary_file(const std::string& path, bool may_not_exist = false) {
  211. FILE* f = fopen(path.c_str(), "rb");
  212. if (!f) {
  213. if (!may_not_exist) {
  214. std::cerr << "Error opening file: " << path << " (" << strerror(errno) << ")\n";
  215. }
  216. return {};
  217. }
  218. fseek(f, 0, SEEK_END);
  219. size_t size = ftell(f);
  220. fseek(f, 0, SEEK_SET);
  221. std::string data(size, '\0');
  222. size_t read_size = fread(data.data(), 1, size, f);
  223. fclose(f);
  224. if (read_size != size) {
  225. std::cerr << "Error reading file: " << path << " (" << strerror(errno) << ")\n";
  226. return {};
  227. }
  228. return data;
  229. }
  230. void write_binary_file(const std::string& path, const std::string& content) {
  231. FILE* f = fopen(path.c_str(), "wb");
  232. if (!f) {
  233. std::cerr << "Error opening file for writing: " << path << " (" << strerror(errno) << ")\n";
  234. return;
  235. }
  236. size_t write_size = fwrite(content.data(), 1, content.size(), f);
  237. fclose(f);
  238. if (write_size != content.size()) {
  239. std::cerr << "Error writing file: " << path << " (" << strerror(errno) << ")\n";
  240. return;
  241. }
  242. }
  243. void write_file_if_changed(const std::string& path, const std::string& content) {
  244. std::string existing = read_binary_file(path, true);
  245. if (existing != content) {
  246. write_binary_file(path, content);
  247. }
  248. }
  249. // variables to track number of compiles in progress
  250. static uint32_t compile_count = 0;
  251. static std::mutex compile_count_mutex;
  252. static std::condition_variable compile_count_cond;
  253. static bool generate_dep_file = true;
  254. void decrement_compile_count(uint32_t * count) {
  255. if (count) {
  256. std::lock_guard<std::mutex> guard(compile_count_mutex);
  257. assert(compile_count > 0);
  258. compile_count--;
  259. compile_count_cond.notify_all();
  260. }
  261. }
  262. using compile_count_guard = std::unique_ptr<uint32_t, decltype(&decrement_compile_count)>;
  263. compile_count_guard acquire_compile_slot() {
  264. // wait until fewer than N compiles are in progress.
  265. // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
  266. uint32_t N = std::max(1u, std::min(16u, std::thread::hardware_concurrency()));
  267. std::unique_lock<std::mutex> guard(compile_count_mutex);
  268. compile_count_cond.wait(guard, [N] { return compile_count < N; });
  269. compile_count++;
  270. return compile_count_guard(&compile_count, &decrement_compile_count);
  271. }
  272. void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map<std::string, std::string> defines, bool coopmat, bool dep_file, compile_count_guard slot) {
  273. std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
  274. #ifdef _WIN32
  275. std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
  276. #else
  277. std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path};
  278. #endif
  279. // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
  280. // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
  281. // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
  282. if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
  283. cmd.push_back("-O");
  284. }
  285. if (dep_file) {
  286. cmd.push_back("-MD");
  287. cmd.push_back("-MF");
  288. #ifdef _WIN32
  289. cmd.push_back("\"" + target_cpp + ".d\"");
  290. #else
  291. cmd.push_back(target_cpp + ".d");
  292. #endif
  293. }
  294. #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
  295. cmd.push_back("-g");
  296. #endif
  297. for (const auto& define : defines) {
  298. cmd.push_back("-D" + define.first + "=" + define.second);
  299. }
  300. std::string command;
  301. for (const auto& part : cmd) {
  302. command += part + " ";
  303. }
  304. std::string stdout_str, stderr_str;
  305. try {
  306. // std::cout << "Executing command: ";
  307. // for (const auto& part : cmd) {
  308. // std::cout << part << " ";
  309. // }
  310. // std::cout << std::endl;
  311. execute_command(cmd, stdout_str, stderr_str);
  312. if (!stderr_str.empty()) {
  313. std::cerr << "cannot compile " << name << "\n\n";
  314. for (const auto& part : cmd) {
  315. std::cerr << part << " ";
  316. }
  317. std::cerr << "\n\n" << stderr_str << std::endl;
  318. return;
  319. }
  320. if (dep_file) {
  321. // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt
  322. std::string dep = read_binary_file(target_cpp + ".d", true);
  323. if (!dep.empty()) {
  324. size_t pos = dep.find(out_path);
  325. if (pos != std::string::npos) {
  326. dep.replace(pos, out_path.length(), target_cpp);
  327. }
  328. write_binary_file(target_cpp + ".d", dep);
  329. }
  330. }
  331. std::lock_guard<std::mutex> guard(lock);
  332. shader_fnames.push_back(std::make_pair(name, out_path));
  333. } catch (const std::exception& e) {
  334. std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
  335. }
  336. }
  337. std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
  338. std::map<std::string, std::string> result = a;
  339. result.insert(b.begin(), b.end());
  340. return result;
  341. }
  342. static std::vector<std::future<void>> compiles;
  343. void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
  344. name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
  345. std::string out_path = join_paths(output_dir, name + ".spv");
  346. if (input_filepath == "") {
  347. // No input source to compile, only generate header for all shaders
  348. shader_fnames.push_back(std::pair(name, out_path));
  349. return;
  350. } else if (basename(input_filepath) != source) {
  351. // Only compile shader variants matching the input filename
  352. return;
  353. }
  354. compile_count_guard slot = acquire_compile_slot();
  355. compiles.push_back(std::async(
  356. string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot)));
  357. // Don't write the same dep file from multiple processes
  358. generate_dep_file = false;
  359. }
  360. void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
  361. std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
  362. std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
  363. std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
  364. std::map<std::string, std::string> base_dict;
  365. std::string shader_name = "matmul";
  366. if (matmul_id_type == MatMulIdType::DEFAULT) {
  367. base_dict["MUL_MAT_ID"] = "1";
  368. shader_name = "matmul_id";
  369. } else if (matmul_id_type == MatMulIdType::SUBGROUP) {
  370. base_dict["MUL_MAT_ID"] = "1";
  371. base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
  372. shader_name = "matmul_id_subgroup";
  373. }
  374. if (fp16) {
  375. base_dict["FLOAT16"] = "1";
  376. }
  377. base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
  378. base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
  379. if (f16acc) {
  380. base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
  381. }
  382. if (coopmat) {
  383. base_dict["COOPMAT"] = "1";
  384. }
  385. const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
  386. auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string {
  387. switch (vec) {
  388. case 1:
  389. if (t == "bf16") {
  390. // scalar path promotes to float
  391. if (!coopmat && !coopmat2) {
  392. return "float";
  393. }
  394. return "bfloat16_t";
  395. }
  396. if (coopmat2 || fp16) {
  397. return "float16_t";
  398. }
  399. return "float";
  400. case 2:
  401. if (t == "bf16") {
  402. // scalar path promotes to float
  403. if (!coopmat && !coopmat2) {
  404. return "vec2";
  405. }
  406. return "bf16vec2";
  407. }
  408. if (coopmat2 || fp16) {
  409. return "f16vec2";
  410. }
  411. return "vec2";
  412. case 4:
  413. if (t == "bf16") {
  414. // scalar path promotes to float
  415. if (!coopmat && !coopmat2) {
  416. return "vec4";
  417. }
  418. return "bf16vec4";
  419. }
  420. if (coopmat2 || fp16) {
  421. return "f16vec4";
  422. }
  423. return "vec4";
  424. case 8:
  425. if (t == "bf16") {
  426. // scalar path promotes to float
  427. if (!coopmat && !coopmat2) {
  428. return "mat2x4";
  429. }
  430. throw std::runtime_error("bf16 vec8 not supported");
  431. }
  432. if (coopmat2 || fp16) {
  433. return "f16mat2x4";
  434. }
  435. return "mat2x4";
  436. default:
  437. throw std::runtime_error("invalid vector size");
  438. }
  439. };
  440. const std::map<std::string, std::string> float_type_dict_f16 = {
  441. {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")},
  442. {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")},
  443. {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")},
  444. {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")},
  445. };
  446. // Shaders with f16 B_TYPE
  447. string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
  448. string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  449. string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
  450. string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  451. // bf16
  452. {
  453. // For aligned matmul loads
  454. std::string load_vec_a = coopmat2 ? "1" : "4";
  455. // scalar path promotes to float
  456. std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
  457. const std::map<std::string, std::string> float_type_dict_bf16 = {
  458. {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")},
  459. {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")},
  460. {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")},
  461. };
  462. // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
  463. #if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
  464. if (!(coopmat || coopmat2))
  465. #endif
  466. {
  467. string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
  468. string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  469. }
  470. }
  471. for (const auto& tname : type_names) {
  472. std::string load_vec_quant = "2";
  473. if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
  474. load_vec_quant = "8";
  475. else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
  476. load_vec_quant = "4";
  477. if (tname == "bf16") {
  478. continue;
  479. }
  480. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  481. // For unaligned, load one at a time for f32/f16, or two at a time for quants
  482. std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
  483. // For aligned matmul loads
  484. std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
  485. const std::map<std::string, std::string> float_type_dict = {
  486. {"FLOAT_TYPE", FLOAT_TYPE(1, tname)},
  487. {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)},
  488. {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)},
  489. {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)},
  490. };
  491. // don't generate f32 variants for coopmat2
  492. if (!coopmat2) {
  493. string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
  494. string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  495. }
  496. if (tname != "f16" && tname != "f32") {
  497. string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
  498. string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  499. }
  500. #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  501. // Integer dot mmq performs better with f32 accumulators
  502. if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
  503. string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
  504. }
  505. #endif
  506. }
  507. }
  508. void process_shaders() {
  509. std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
  510. // matmul
  511. for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
  512. // No coopmats
  513. // fp32
  514. matmul_shaders(false, matmul_id_type, false, false, false);
  515. // fp16, fp32acc and fp16acc
  516. matmul_shaders(true, matmul_id_type, false, false, false);
  517. matmul_shaders(true, matmul_id_type, false, false, true);
  518. if (matmul_id_type != MatMulIdType::DEFAULT) {
  519. #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  520. // Coopmat, fp32acc and fp16acc
  521. matmul_shaders(true, matmul_id_type, true, false, false);
  522. matmul_shaders(true, matmul_id_type, true, false, true);
  523. #endif
  524. #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
  525. // Coopmat2, fp32acc and fp16acc
  526. matmul_shaders(true, matmul_id_type, false, true, false);
  527. matmul_shaders(true, matmul_id_type, false, true, true);
  528. #endif
  529. }
  530. }
  531. // flash attention
  532. for (const auto& f16acc : {false, true}) {
  533. std::map<std::string, std::string> fa_base_dict = base_dict;
  534. fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
  535. fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
  536. if (f16acc) {
  537. fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
  538. }
  539. for (const auto& tname : type_names) {
  540. if (tname == "bf16") continue;
  541. #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
  542. if (tname == "f16") {
  543. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
  544. merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
  545. } else {
  546. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  547. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
  548. merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
  549. }
  550. #endif
  551. #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  552. if (tname == "f16") {
  553. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
  554. merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
  555. } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
  556. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  557. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
  558. merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
  559. }
  560. #endif
  561. if (tname == "f16") {
  562. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
  563. merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
  564. } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
  565. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  566. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
  567. merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
  568. }
  569. }
  570. }
  571. for (const auto& tname : type_names) {
  572. // mul mat vec
  573. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  574. std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
  575. string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
  576. string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
  577. string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
  578. string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
  579. string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
  580. string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
  581. string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
  582. // mul mat vec with integer dot product
  583. #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  584. if (is_legacy_quant(tname)) {
  585. string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
  586. string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
  587. string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
  588. }
  589. #endif
  590. // Dequant shaders
  591. if (tname != "f16" && tname != "bf16") {
  592. string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
  593. }
  594. shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
  595. if (tname == "f16") {
  596. string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
  597. } else {
  598. string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
  599. }
  600. string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
  601. }
  602. string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
  603. string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
  604. string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
  605. // Norms
  606. string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  607. string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  608. string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  609. string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  610. string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
  611. string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}}));
  612. string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  613. string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  614. string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  615. string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
  616. string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  617. string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  618. string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
  619. string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  620. string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
  621. string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
  622. string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
  623. string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  624. string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  625. string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
  626. string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
  627. string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
  628. string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
  629. string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}});
  630. for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
  631. string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  632. string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
  633. string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  634. }
  635. for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
  636. string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  637. string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
  638. string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  639. string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
  640. }
  641. auto get_type_str = [](bool f16) {
  642. return f16 ? "float16_t" : "float";
  643. };
  644. auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
  645. std::string s;
  646. s += std::string(src0_f16 ? "_f16" : "_f32");
  647. s += std::string(src1_f16 ? "_f16" : "_f32");
  648. s += std::string(dst_f16 ? "_f16" : "_f32");
  649. return s;
  650. };
  651. for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
  652. for (auto src0_f16 : {false, true}) {
  653. for (auto src1_f16 : {false, true}) {
  654. for (auto dst_f16 : {false, true}) {
  655. for (auto rte : {false, true}) {
  656. auto source = op == "add_rms" ? std::string("add") : op;
  657. auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
  658. auto add_rms = op == "add_rms" ? "1" : "0";
  659. string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
  660. }
  661. }
  662. }
  663. }
  664. }
  665. string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  666. string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  667. string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
  668. string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
  669. string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
  670. string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
  671. string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}});
  672. string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}});
  673. string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  674. string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  675. string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  676. string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  677. string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  678. string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  679. string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  680. string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  681. string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  682. string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  683. string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  684. string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  685. string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  686. string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
  687. string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  688. for (auto rte : {false, true}) {
  689. std::string suffix = rte ? "_rte" : "";
  690. string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  691. string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}});
  692. string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  693. string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  694. }
  695. string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  696. string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  697. string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  698. string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  699. string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  700. string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  701. string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  702. string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  703. string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  704. string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  705. string_to_spv("neg_f16", "neg.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  706. string_to_spv("neg_f32", "neg.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  707. string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  708. string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  709. string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  710. string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  711. string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  712. string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  713. string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  714. string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  715. string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  716. string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  717. string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  718. string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  719. string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
  720. string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
  721. string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  722. string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  723. string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  724. string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  725. string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  726. string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  727. string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  728. string_to_spv("ceil_f16", "ceil.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  729. string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  730. string_to_spv("floor_f16", "floor.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  731. string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  732. string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  733. string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  734. for (auto rte : {false, true}) {
  735. std::string suffix = rte ? "_rte" : "";
  736. string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  737. string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  738. string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  739. string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  740. string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  741. string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  742. string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  743. string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  744. string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  745. string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  746. string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
  747. string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
  748. }
  749. string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  750. string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  751. string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  752. string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  753. string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
  754. string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  755. string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
  756. string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
  757. string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  758. string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
  759. string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  760. string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
  761. string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
  762. string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  763. string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
  764. string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  765. string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
  766. string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
  767. string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  768. string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
  769. string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
  770. string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
  771. string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
  772. string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
  773. string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
  774. string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
  775. string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
  776. string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  777. string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
  778. string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  779. for (std::string dim_str : {"", "_3d"}) {
  780. for (bool bda : {false, true}) {
  781. std::string bda_str = bda ? "_bda" : "";
  782. std::string bda_def = bda ? "1" : "0";
  783. string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
  784. string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
  785. string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
  786. }
  787. }
  788. string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  789. string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  790. string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  791. string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  792. string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  793. string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  794. string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  795. for (auto transpose : {false, true}) {
  796. for (auto unroll : {false, true}) {
  797. for (auto a_f16 : {false, true}) {
  798. std::map<std::string, std::string> defines = {
  799. {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
  800. {"USE_COLLECTIVES", "1"}, {"UNROLL", unroll ? "[[unroll]]" : ""},
  801. };
  802. if (transpose) defines["TRANSPOSE"] = "1";
  803. std::string name = std::string(transpose ? "conv_transpose_2d": "conv2d")
  804. + (a_f16 ? "_f16" : "") + "_f32";
  805. string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines);
  806. #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
  807. if (unroll) {
  808. defines["COOPMAT2"] = "1";
  809. string_to_spv(name, "conv2d_mm.comp", defines, true, false, true);
  810. }
  811. #endif
  812. }
  813. }
  814. }
  815. string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
  816. string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
  817. string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
  818. string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
  819. string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  820. string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  821. string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
  822. string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
  823. string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
  824. string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
  825. string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
  826. string_to_spv("topk_moe_f32", "topk_moe.comp", {});
  827. for (auto &c : compiles) {
  828. c.wait();
  829. }
  830. }
  831. void write_output_files() {
  832. std::stringstream hdr = make_generic_stringstream();
  833. std::stringstream src = make_generic_stringstream();
  834. hdr << "#include <cstdint>\n\n";
  835. src << "#include \"" << basename(target_hpp) << "\"\n\n";
  836. std::sort(shader_fnames.begin(), shader_fnames.end());
  837. for (const auto& pair : shader_fnames) {
  838. const std::string& name = pair.first;
  839. #ifdef _WIN32
  840. std::string path = pair.second;
  841. std::replace(path.begin(), path.end(), '/', '\\' );
  842. #else
  843. const std::string& path = pair.second;
  844. #endif
  845. hdr << "extern const uint64_t " << name << "_len;\n";
  846. hdr << "extern const unsigned char " << name << "_data[];\n\n";
  847. if (input_filepath != "") {
  848. std::string data = read_binary_file(path);
  849. if (data.empty()) {
  850. continue;
  851. }
  852. src << "const uint64_t " << name << "_len = " << data.size() << ";\n";
  853. src << "const unsigned char " << name << "_data[" << data.size() << "] = {\n" << std::hex;
  854. auto bytes = reinterpret_cast<const uint8_t*>(data.data());
  855. for (size_t i = 0; i < data.size(); ++i) {
  856. src << "0x" << static_cast<int>(bytes[i]) << ",";
  857. if ((i + 1) % 12 == 0) src << "\n";
  858. }
  859. src << std::dec << "\n};\n\n";
  860. }
  861. }
  862. std::string suffixes[2] = {"_f32", "_f16"};
  863. for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
  864. hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
  865. hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
  866. std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp";
  867. if (basename(input_filepath) != op_file) {
  868. continue;
  869. }
  870. std::stringstream data = make_generic_stringstream();
  871. std::stringstream len = make_generic_stringstream();
  872. data << "const void * " << op << "_data[2][2][2][2] = ";
  873. len << "const uint64_t " << op << "_len[2][2][2][2] = ";
  874. for (uint32_t t0 = 0; t0 < 2; ++t0) {
  875. if (t0 == 0) {
  876. data << "{";
  877. len << "{";
  878. }
  879. for (uint32_t t1 = 0; t1 < 2; ++t1) {
  880. if (t1 == 0) {
  881. data << "{";
  882. len << "{";
  883. }
  884. for (uint32_t t2 = 0; t2 < 2; ++t2) {
  885. if (t2 == 0) {
  886. data << "{";
  887. len << "{";
  888. }
  889. for (uint32_t rte = 0; rte < 2; ++rte) {
  890. if (rte == 0) {
  891. data << "{";
  892. len << "{";
  893. }
  894. data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
  895. len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
  896. data << "_data,";
  897. len << "_len,";
  898. if (rte == 1) {
  899. data << "}, ";
  900. len << "}, ";
  901. }
  902. }
  903. if (t2 == 1) {
  904. data << "}, ";
  905. len << "}, ";
  906. }
  907. }
  908. if (t1 == 1) {
  909. data << "}, ";
  910. len << "}, ";
  911. }
  912. }
  913. if (t0 == 1) {
  914. data << "};\n";
  915. len << "};\n";
  916. }
  917. }
  918. src << data.str();
  919. src << len.str();
  920. }
  921. std::vector<std::string> btypes = {"f16", "f32"};
  922. #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  923. btypes.push_back("q8_1");
  924. #endif
  925. for (const std::string& btype : btypes) {
  926. for (const auto& tname : type_names) {
  927. if (btype == "q8_1" && !is_legacy_quant(tname)) {
  928. continue;
  929. }
  930. hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
  931. hdr << "extern const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3];\n";
  932. if (basename(input_filepath) == "mul_mat_vec.comp") {
  933. src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
  934. src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
  935. }
  936. }
  937. }
  938. if (input_filepath == "") {
  939. write_file_if_changed(target_hpp, hdr.str());
  940. }
  941. if (target_cpp != "") {
  942. write_binary_file(target_cpp, src.str());
  943. }
  944. }
  945. } // namespace
  946. int main(int argc, char** argv) {
  947. std::map<std::string, std::string> args;
  948. for (int i = 1; i < argc; ++i) {
  949. std::string arg = argv[i];
  950. if (arg.rfind("--", 0) == 0) {
  951. if (i + 1 < argc && argv[i + 1][0] != '-') {
  952. args[arg] = argv[i + 1];
  953. ++i;
  954. } else {
  955. args[arg] = "";
  956. }
  957. }
  958. }
  959. if (args.find("--glslc") != args.end()) {
  960. GLSLC = args["--glslc"]; // Path to glslc
  961. }
  962. if (args.find("--source") != args.end()) {
  963. input_filepath = args["--source"]; // The shader source file to compile
  964. }
  965. if (args.find("--output-dir") != args.end()) {
  966. output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
  967. }
  968. if (args.find("--target-hpp") != args.end()) {
  969. target_hpp = args["--target-hpp"]; // Path to generated header file
  970. }
  971. if (args.find("--target-cpp") != args.end()) {
  972. target_cpp = args["--target-cpp"]; // Path to generated cpp file
  973. }
  974. if (!directory_exists(output_dir)) {
  975. if (!create_directory(output_dir)) {
  976. std::cerr << "Error creating output directory: " << output_dir << "\n";
  977. return EXIT_FAILURE;
  978. }
  979. }
  980. process_shaders();
  981. write_output_files();
  982. return EXIT_SUCCESS;
  983. }