// vulkan-shaders-gen.cpp (35 KB)
#include <algorithm>
#include <array>
#include <cassert>
#include <cctype>
#include <cerrno>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <future>
#include <iostream>
#include <map>
#include <mutex>
#include <queue>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <sys/stat.h>
#include <sys/types.h>

#ifdef _WIN32
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
#else
#include <fcntl.h>
#include <poll.h>
#include <sys/wait.h>
#include <unistd.h>
#endif
  29. #define ASYNCIO_CONCURRENCY 64
  30. std::mutex lock;
  31. std::vector<std::pair<std::string, std::string>> shader_fnames;
  32. std::string GLSLC = "glslc";
  33. std::string input_dir = "vulkan-shaders";
  34. std::string output_dir = "/tmp";
  35. std::string target_hpp = "ggml-vulkan-shaders.hpp";
  36. std::string target_cpp = "ggml-vulkan-shaders.cpp";
  37. bool no_clean = false;
  38. const std::vector<std::string> type_names = {
  39. "f32",
  40. "f16",
  41. "q4_0",
  42. "q4_1",
  43. "q5_0",
  44. "q5_1",
  45. "q8_0",
  46. "q2_k",
  47. "q3_k",
  48. "q4_k",
  49. "q5_k",
  50. "q6_k",
  51. "iq1_s",
  52. "iq1_m",
  53. "iq2_xxs",
  54. "iq2_xs",
  55. "iq2_s",
  56. "iq3_xxs",
  57. "iq3_s",
  58. "iq4_xs",
  59. "iq4_nl",
  60. "bf16",
  61. };
  62. namespace {
  63. void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
  64. #ifdef _WIN32
  65. HANDLE stdout_read, stdout_write;
  66. HANDLE stderr_read, stderr_write;
  67. SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };
  68. if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
  69. !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
  70. throw std::runtime_error("Failed to create stdout pipe");
  71. }
  72. if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
  73. !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
  74. throw std::runtime_error("Failed to create stderr pipe");
  75. }
  76. PROCESS_INFORMATION pi;
  77. STARTUPINFOA si = {};
  78. si.cb = sizeof(STARTUPINFOA);
  79. si.dwFlags = STARTF_USESTDHANDLES;
  80. si.hStdOutput = stdout_write;
  81. si.hStdError = stderr_write;
  82. std::vector<char> cmd(command.begin(), command.end());
  83. cmd.push_back('\0');
  84. if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
  85. throw std::runtime_error("Failed to create process");
  86. }
  87. CloseHandle(stdout_write);
  88. CloseHandle(stderr_write);
  89. std::array<char, 128> buffer;
  90. DWORD bytes_read;
  91. while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
  92. stdout_str.append(buffer.data(), bytes_read);
  93. }
  94. while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
  95. stderr_str.append(buffer.data(), bytes_read);
  96. }
  97. CloseHandle(stdout_read);
  98. CloseHandle(stderr_read);
  99. WaitForSingleObject(pi.hProcess, INFINITE);
  100. CloseHandle(pi.hProcess);
  101. CloseHandle(pi.hThread);
  102. #else
  103. int stdout_pipe[2];
  104. int stderr_pipe[2];
  105. if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
  106. throw std::runtime_error("Failed to create pipes");
  107. }
  108. pid_t pid = fork();
  109. if (pid < 0) {
  110. throw std::runtime_error("Failed to fork process");
  111. }
  112. if (pid == 0) {
  113. close(stdout_pipe[0]);
  114. close(stderr_pipe[0]);
  115. dup2(stdout_pipe[1], STDOUT_FILENO);
  116. dup2(stderr_pipe[1], STDERR_FILENO);
  117. close(stdout_pipe[1]);
  118. close(stderr_pipe[1]);
  119. execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
  120. _exit(EXIT_FAILURE);
  121. } else {
  122. close(stdout_pipe[1]);
  123. close(stderr_pipe[1]);
  124. std::array<char, 128> buffer;
  125. ssize_t bytes_read;
  126. while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
  127. stdout_str.append(buffer.data(), bytes_read);
  128. }
  129. while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
  130. stderr_str.append(buffer.data(), bytes_read);
  131. }
  132. close(stdout_pipe[0]);
  133. close(stderr_pipe[0]);
  134. waitpid(pid, nullptr, 0);
  135. }
  136. #endif
  137. }
  138. bool directory_exists(const std::string& path) {
  139. struct stat info;
  140. if (stat(path.c_str(), &info) != 0) {
  141. return false; // Path doesn't exist or can't be accessed
  142. }
  143. return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory
  144. }
  145. bool create_directory(const std::string& path) {
  146. #ifdef _WIN32
  147. return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists
  148. #else
  149. return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions
  150. #endif
  151. }
  152. std::string to_uppercase(const std::string& input) {
  153. std::string result = input;
  154. for (char& c : result) {
  155. c = std::toupper(c);
  156. }
  157. return result;
  158. }
  159. bool string_starts_with(const std::string& str, const std::string& prefix) {
  160. if (prefix.size() > str.size()) {
  161. return false;
  162. }
  163. return std::equal(prefix.begin(), prefix.end(), str.begin());
  164. }
  165. bool string_ends_with(const std::string& str, const std::string& suffix) {
  166. if (suffix.size() > str.size()) {
  167. return false;
  168. }
  169. return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
  170. }
  171. static const char path_separator = '/';
  172. std::string join_paths(const std::string& path1, const std::string& path2) {
  173. return path1 + path_separator + path2;
  174. }
  175. std::string basename(const std::string &path) {
  176. return path.substr(path.find_last_of("/\\") + 1);
  177. }
  178. // variables to track number of compiles in progress
  179. static uint32_t compile_count = 0;
  180. static std::mutex compile_count_mutex;
  181. static std::condition_variable compile_count_cond;
  182. void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
  183. std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
  184. std::string out_fname = join_paths(output_dir, name + ".spv");
  185. std::string in_path = join_paths(input_dir, in_fname);
  186. std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
  187. // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
  188. std::string opt_level = coopmat ? "" : "-O";
  189. #ifdef _WIN32
  190. std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
  191. #else
  192. std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname};
  193. #endif
  194. #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
  195. cmd.push_back("-g");
  196. #endif
  197. for (const auto& define : defines) {
  198. cmd.push_back("-D" + define.first + "=" + define.second);
  199. }
  200. std::string command;
  201. for (const auto& part : cmd) {
  202. command += part + " ";
  203. }
  204. std::string stdout_str, stderr_str;
  205. try {
  206. // std::cout << "Executing command: ";
  207. // for (const auto& part : cmd) {
  208. // std::cout << part << " ";
  209. // }
  210. // std::cout << std::endl;
  211. execute_command(command, stdout_str, stderr_str);
  212. if (!stderr_str.empty()) {
  213. std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
  214. return;
  215. }
  216. std::lock_guard<std::mutex> guard(lock);
  217. shader_fnames.push_back(std::make_pair(name, out_fname));
  218. } catch (const std::exception& e) {
  219. std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
  220. }
  221. {
  222. std::lock_guard<std::mutex> guard(compile_count_mutex);
  223. assert(compile_count > 0);
  224. compile_count--;
  225. }
  226. compile_count_cond.notify_all();
  227. }
  228. std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
  229. std::map<std::string, std::string> result = a;
  230. result.insert(b.begin(), b.end());
  231. return result;
  232. }
  233. static std::vector<std::future<void>> compiles;
  234. void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
  235. {
  236. // wait until fewer than N compiles are in progress.
  237. // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
  238. uint32_t N = 16;
  239. std::unique_lock<std::mutex> guard(compile_count_mutex);
  240. while (compile_count >= N) {
  241. compile_count_cond.wait(guard);
  242. }
  243. compile_count++;
  244. }
  245. compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
  246. }
  247. void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
  248. std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
  249. std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
  250. std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
  251. std::map<std::string, std::string> base_dict = {
  252. {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
  253. };
  254. std::string shader_name = "matmul";
  255. if (matmul_id) {
  256. base_dict["MUL_MAT_ID"] = "1";
  257. shader_name = "matmul_id";
  258. }
  259. if (fp16) {
  260. base_dict["FLOAT16"] = "1";
  261. }
  262. base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
  263. if (coopmat) {
  264. base_dict["COOPMAT"] = "1";
  265. }
  266. const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
  267. auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
  268. if (t == "bf16") {
  269. // scalar path promotes to float
  270. if (!coopmat && !coopmat2) {
  271. return "float";
  272. }
  273. return "bfloat16_t";
  274. }
  275. if (coopmat2 || fp16) {
  276. return "float16_t";
  277. }
  278. return "float";
  279. };
  280. // Shaders with f16 B_TYPE
  281. string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
  282. string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  283. string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  284. string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
  285. // bf16
  286. {
  287. std::string load_vec_a_unaligned = "1";
  288. // For aligned matmul loads
  289. std::string load_vec_a = coopmat2 ? "1" : "4";
  290. // scalar path promotes to float
  291. std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
  292. // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
  293. #if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
  294. if (!(coopmat || coopmat2))
  295. #endif
  296. {
  297. string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  298. string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
  299. }
  300. }
  301. for (const auto& tname : type_names) {
  302. std::string load_vec_quant = "2";
  303. if ((tname == "q4_0") || (tname == "q4_1"))
  304. load_vec_quant = "8";
  305. else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
  306. load_vec_quant = "4";
  307. if (tname == "bf16") {
  308. continue;
  309. }
  310. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  311. // For unaligned, load one at a time for f32/f16, or two at a time for quants
  312. std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
  313. // For aligned matmul loads
  314. std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
  315. // don't generate f32 variants for coopmat2
  316. if (!coopmat2) {
  317. string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
  318. string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  319. }
  320. if (tname != "f16" && tname != "f32") {
  321. string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
  322. string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  323. }
  324. #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  325. if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
  326. string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
  327. }
  328. #endif
  329. }
  330. }
  331. void process_shaders() {
  332. std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
  333. std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
  334. // matmul
  335. for (const auto& matmul_id : {false, true}) {
  336. // No coopmats
  337. // fp32
  338. matmul_shaders(false, matmul_id, false, false, false);
  339. // fp16, fp32acc and fp16acc
  340. matmul_shaders(true, matmul_id, false, false, false);
  341. matmul_shaders(true, matmul_id, false, false, true);
  342. #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  343. // Coopmat, fp32acc and fp16acc
  344. matmul_shaders(true, matmul_id, true, false, false);
  345. matmul_shaders(true, matmul_id, true, false, true);
  346. #endif
  347. #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
  348. // Coopmat2, fp32acc and fp16acc
  349. matmul_shaders(true, matmul_id, false, true, false);
  350. matmul_shaders(true, matmul_id, false, true, true);
  351. #endif
  352. }
  353. // flash attention
  354. for (const auto& f16acc : {false, true}) {
  355. std::string acctype = f16acc ? "float16_t" : "float";
  356. std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
  357. for (const auto& tname : type_names) {
  358. if (tname == "f32") {
  359. continue;
  360. }
  361. if (tname == "bf16") continue;
  362. #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
  363. if (tname == "f16") {
  364. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
  365. merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
  366. } else {
  367. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  368. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
  369. merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
  370. }
  371. #endif
  372. #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  373. if (tname == "f16") {
  374. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
  375. merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
  376. } else if (tname == "q4_0" || tname == "q8_0") {
  377. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  378. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
  379. merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
  380. }
  381. #endif
  382. if (tname == "f16") {
  383. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
  384. merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
  385. } else if (tname == "q4_0" || tname == "q8_0") {
  386. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  387. string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
  388. merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
  389. }
  390. }
  391. }
  392. for (const auto& tname : type_names) {
  393. // mul mat vec
  394. std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  395. std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
  396. string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
  397. string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
  398. string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
  399. // Dequant shaders
  400. if (tname != "f16" && tname != "bf16") {
  401. string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
  402. }
  403. if (!string_ends_with(tname, "_k")) {
  404. shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
  405. if (tname == "f16") {
  406. string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
  407. } else {
  408. string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
  409. }
  410. string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
  411. }
  412. }
  413. string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
  414. string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
  415. string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
  416. // Norms
  417. string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  418. string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  419. string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  420. string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  421. string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  422. string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  423. string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
  424. string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  425. string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  426. string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
  427. string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  428. string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
  429. string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  430. string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  431. string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
  432. for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
  433. string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  434. string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
  435. string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  436. }
  437. auto get_type_str = [](bool f16) {
  438. return f16 ? "float16_t" : "float";
  439. };
  440. auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
  441. std::string s;
  442. s += std::string(src0_f16 ? "_f16" : "_f32");
  443. s += std::string(src1_f16 ? "_f16" : "_f32");
  444. s += std::string(dst_f16 ? "_f16" : "_f32");
  445. return s;
  446. };
  447. for (std::string op : {"add", "sub", "mul", "div"}) {
  448. for (auto src0_f16 : {false, true}) {
  449. for (auto src1_f16 : {false, true}) {
  450. for (auto dst_f16 : {false, true}) {
  451. auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
  452. string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
  453. }
  454. }
  455. }
  456. }
  457. string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  458. string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  459. string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
  460. string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
  461. string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
  462. string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  463. string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  464. string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  465. string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  466. string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  467. string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  468. string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  469. string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  470. string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  471. string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  472. string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  473. string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
  474. string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
  475. string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  476. string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  477. string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  478. string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  479. string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  480. string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  481. string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  482. string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  483. string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  484. string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  485. string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  486. string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  487. string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  488. string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  489. string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  490. string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  491. string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  492. string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
  493. string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  494. string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  495. string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  496. string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
  497. string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  498. string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  499. string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
  500. string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  501. string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  502. string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
  503. string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  504. string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
  505. string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
  506. string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
  507. string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
  508. string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  509. string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
  510. string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  511. string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
  512. string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
  513. string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  514. string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  515. string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  516. string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  517. string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
  518. string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
  519. string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
  520. for (auto &c : compiles) {
  521. c.wait();
  522. }
  523. }
  524. void write_output_files() {
  525. FILE* hdr = fopen(target_hpp.c_str(), "w");
  526. FILE* src = fopen(target_cpp.c_str(), "w");
  527. fprintf(hdr, "#include <cstdint>\n\n");
  528. fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
  529. std::sort(shader_fnames.begin(), shader_fnames.end());
  530. for (const auto& pair : shader_fnames) {
  531. const std::string& name = pair.first;
  532. #ifdef _WIN32
  533. std::string path = pair.second;
  534. std::replace(path.begin(), path.end(), '/', '\\' );
  535. #else
  536. const std::string& path = pair.second;
  537. #endif
  538. FILE* spv = fopen(path.c_str(), "rb");
  539. if (!spv) {
  540. std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
  541. continue;
  542. }
  543. fseek(spv, 0, SEEK_END);
  544. size_t size = ftell(spv);
  545. fseek(spv, 0, SEEK_SET);
  546. std::vector<unsigned char> data(size);
  547. size_t read_size = fread(data.data(), 1, size, spv);
  548. fclose(spv);
  549. if (read_size != size) {
  550. std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
  551. continue;
  552. }
  553. fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
  554. fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
  555. fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
  556. for (size_t i = 0; i < size; ++i) {
  557. fprintf(src, "0x%02x,", data[i]);
  558. if ((i + 1) % 12 == 0) fprintf(src, "\n");
  559. }
  560. fprintf(src, "\n};\n\n");
  561. if (!no_clean) {
  562. std::remove(path.c_str());
  563. }
  564. }
  565. for (const char *op : {"add", "sub", "mul", "div"}) {
  566. fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
  567. fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
  568. fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
  569. fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
  570. }
  571. fclose(hdr);
  572. fclose(src);
  573. }
  574. }
  575. int main(int argc, char** argv) {
  576. std::map<std::string, std::string> args;
  577. for (int i = 1; i < argc; ++i) {
  578. std::string arg = argv[i];
  579. if (arg.rfind("--", 0) == 0) {
  580. if (i + 1 < argc && argv[i + 1][0] != '-') {
  581. args[arg] = argv[i + 1];
  582. ++i;
  583. } else {
  584. args[arg] = "";
  585. }
  586. }
  587. }
  588. if (args.find("--glslc") != args.end()) {
  589. GLSLC = args["--glslc"]; // Path to glslc
  590. }
  591. if (args.find("--input-dir") != args.end()) {
  592. input_dir = args["--input-dir"]; // Directory containing shader sources
  593. }
  594. if (args.find("--output-dir") != args.end()) {
  595. output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
  596. }
  597. if (args.find("--target-hpp") != args.end()) {
  598. target_hpp = args["--target-hpp"]; // Path to generated header file
  599. }
  600. if (args.find("--target-cpp") != args.end()) {
  601. target_cpp = args["--target-cpp"]; // Path to generated cpp file
  602. }
  603. if (args.find("--no-clean") != args.end()) {
  604. no_clean = true; // Keep temporary SPIR-V files in output-dir after build
  605. }
  606. if (!directory_exists(input_dir)) {
  607. std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
  608. return EXIT_FAILURE;
  609. }
  610. if (!directory_exists(output_dir)) {
  611. if (!create_directory(output_dir)) {
  612. std::cerr << "Error creating output directory: " << output_dir << "\n";
  613. return EXIT_FAILURE;
  614. }
  615. }
  616. process_shaders();
  617. write_output_files();
  618. return EXIT_SUCCESS;
  619. }