#!/usr/bin/env python
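"""Compile the GGML Vulkan compute shaders to SPIR-V with glslc and embed them in ggml-vulkan-shaders.hpp."""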
import logging
import argparse
import asyncio
import os
from tempfile import gettempdir

logger = logging.getLogger("ggml-vk-generate-shaders")
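# Default glslc binary; can be overridden with the --glslc command-line option.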
GLSLC = "glslc"
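# Tensor data types for which type-specialized shader variants are generated.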
type_names = [
    "f32",
    "f16",
    "q4_0",
    "q4_1",
    "q5_0",
    "q5_1",
    "q8_0",
    "q2_k",
    "q3_k",
    "q4_k",
    "q5_k",
    "q6_k",
]
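# Maximum number of glslc processes allowed to run at once.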
ASYNCIO_CONCURRENCY = 64

input_dir = "vulkan-shaders"
output_dir = gettempdir()
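# Compiled shader (name, path) pairs, appended under `lock` by the concurrent compile tasks.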
lock = asyncio.Lock()
shader_fnames = []

async def string_to_spv(name, in_fname, defines, fp16=True):
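    """Compile a single GLSL compute shader to SPIR-V with glslc and record its output file."""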
    name = f"{name}{'_fp32' if not fp16 else ''}"
    out_fname = os.path.join(output_dir, f"{name}.spv")

    in_path = os.path.join(input_dir, in_fname)

    cmd = [GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname]
    cmd.extend([f"-D{key}={value}" for key, value in defines.items()])

    proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)

    stdout, stderr = await proc.communicate()

    stdout = stdout.decode()
    error = stderr.decode()

    if proc.returncode:
        cmd = " ".join(cmd)
        logger.error(f"cannot compile {name}\n\n{cmd}\n\n{error}")
        return

    async with lock:
        shader_fnames.append((name, out_fname))

def matmul_shaders(tasks, fp16, matmul_id):
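    """Queue matmul (and matmul_id) shader compilations for every supported A type."""
    # fp16 builds use 8-wide loads and 2x4 matrix types for aligned B loads;
    # fp32 builds use 4-wide vectors.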
    if fp16:
        load_vec = "8"
        aligned_b_type_f32 = "mat2x4"
        aligned_b_type_f16 = "f16mat2x4"
    else:
        load_vec = "4"
        aligned_b_type_f32 = "vec4"
        aligned_b_type_f16 = "f16vec4"

    base_dict = {"FLOAT_TYPE": "float" if not fp16 else "float16_t"}
    shader_name = "matmul"

    if matmul_id:
        base_dict["MUL_MAT_ID"] = "1"
        shader_name = "matmul_id"

    if fp16:
        base_dict["FLOAT16"] = "1"

    # Shaders with f16 B_TYPE
    tasks.append(string_to_spv(f"{shader_name}_f32_f16", "mul_mm.comp", base_dict | {"DATA_A_F32": "1", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
    tasks.append(string_to_spv(f"{shader_name}_f32_f16_aligned", "mul_mm.comp", base_dict | {"DATA_A_F32": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f16, "D_TYPE": "float"}, fp16))

    tasks.append(string_to_spv(f"{shader_name}_f16", "mul_mm.comp", base_dict | {"DATA_A_F16": "1", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
    tasks.append(string_to_spv(f"{shader_name}_f16_aligned", "mul_mm.comp", base_dict | {"DATA_A_F16": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f16, "D_TYPE": "float"}, fp16))

    for tname in type_names:
        data_a_key = f"DATA_A_{tname.upper()}"
        load_vec_a = load_vec if tname in ("f32", "f16") else "2"
        tasks.append(string_to_spv(f"{shader_name}_{tname}_f32", "mul_mm.comp", base_dict | {data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
        tasks.append(string_to_spv(f"{shader_name}_{tname}_f32_aligned", "mul_mm.comp", base_dict | {data_a_key: "2", "LOAD_VEC_A": load_vec_a, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f32, "D_TYPE": "float"}, fp16))

async def main():
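    """Build the full shader list, compile everything concurrently, then emit ggml-vulkan-shaders.hpp."""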
    logger.info("ggml_vulkan: Generating and compiling shaders to SPIR-V")

    tasks = []

    base_dict = {"FLOAT_TYPE": "float"}
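    # Build matmul and matmul_id pipelines in both fp32 and fp16 variants.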
    for fp16 in (False, True):
        # MUL_MAT
        matmul_shaders(tasks, fp16, False)
        # MUL_MAT_ID
        matmul_shaders(tasks, fp16, True)
    for tname in type_names:
        # mul mat vec
        data_a_key = f"DATA_A_{tname.upper()}"
        shader = f"mul_mat_vec_{tname}.comp" if tname.endswith("_k") else "mul_mat_vec.comp"

        tasks.append(string_to_spv(f"mul_mat_vec_{tname}_f32_f32", shader, base_dict | {data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}))
        tasks.append(string_to_spv(f"mul_mat_vec_{tname}_f16_f32", shader, base_dict | {data_a_key: "1", "B_TYPE": "float16_t", "D_TYPE": "float"}))

        tasks.append(string_to_spv(f"mul_mat_vec_id_{tname}_f32", shader, base_dict | {"MUL_MAT_ID": "1", data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}))

        # Dequant shaders
        if tname != "f16":
            tasks.append(string_to_spv(f"dequant_{tname}", f"dequant_{tname}.comp", base_dict | {data_a_key: "1", "D_TYPE": "float16_t"}))

        # get_rows
        if not tname.endswith("_k"):
            shader = "get_rows.comp" if tname in ("f32", "f16") else "get_rows_quant.comp"

            if tname == "f16":
                tasks.append(string_to_spv(f"get_rows_{tname}", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))
            else:
                tasks.append(string_to_spv(f"get_rows_{tname}", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float16_t"}))
            tasks.append(string_to_spv(f"get_rows_{tname}_f32", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float"}))
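    # Special-case f16 mat-vec kernels (p021 and nc variants).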
    tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
    tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))

    # Norms
    tasks.append(string_to_spv("norm_f32", "norm.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
    tasks.append(string_to_spv("rms_norm_f32", "rms_norm.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))

    tasks.append(string_to_spv("cpy_f32_f32", "copy.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
    tasks.append(string_to_spv("cpy_f32_f16", "copy.comp", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
    tasks.append(string_to_spv("cpy_f16_f16", "copy.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))

    tasks.append(string_to_spv("add_f32", "add.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))

    tasks.append(string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}))

    tasks.append(string_to_spv("mul_f32", "mul.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
    tasks.append(string_to_spv("div_f32", "div.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))

    tasks.append(string_to_spv("scale_f32", "scale.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
    tasks.append(string_to_spv("sqr_f32", "square.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
    tasks.append(string_to_spv("clamp_f32", "clamp.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))

    tasks.append(string_to_spv("gelu_f32", "gelu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
    tasks.append(string_to_spv("silu_f32", "silu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
    tasks.append(string_to_spv("relu_f32", "relu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))

    tasks.append(string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {"A_TYPE": "float", "D_TYPE": "float"}))

    tasks.append(string_to_spv("soft_max_f32", "soft_max.comp", base_dict | {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
    tasks.append(string_to_spv("soft_max_f32_f16", "soft_max.comp", base_dict | {"A_TYPE": "float", "B_TYPE": "float16_t", "D_TYPE": "float"}))

    tasks.append(string_to_spv("rope_norm_f32", "rope_norm.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
    tasks.append(string_to_spv("rope_norm_f16", "rope_norm.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))

    tasks.append(string_to_spv("rope_neox_f32", "rope_neox.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
    tasks.append(string_to_spv("rope_neox_f16", "rope_neox.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))

    tasks.append(string_to_spv("argsort_f32", "argsort.comp", {"A_TYPE": "float"}))

    tasks.append(string_to_spv("sum_rows_f32", "sum_rows.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
    # Helper to decorate tasks with semaphore acquisition.
    async def withSemaphore(sem, task):
        async with sem:
            return await task

    # Run tasks concurrently guarded by a concurrency limit.
    sem = asyncio.Semaphore(ASYNCIO_CONCURRENCY)

    await asyncio.gather(*(withSemaphore(sem, task) for task in tasks))
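    # Embed each compiled SPIR-V blob into the header as a byte array plus a length
    # constant (illustrative shape):
    #   unsigned char <name>_data[] = { ... };
    #   const uint64_t <name>_len = ...;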
    with open("ggml-vulkan-shaders.hpp", "w") as f:
        f.write("#include <cstdint>\n\n")
        for name, path in sorted(shader_fnames):
            with open(path, "rb") as spv:
                counter = 0
                newline_counter = 0
                f.write(f"unsigned char {name}_data[] = {{\n")
                for val in spv.read():
                    f.write(f"0x{val:02x},")
                    newline_counter += 1
                    counter += 1
                    if newline_counter >= 12:
                        newline_counter = 0
                        f.write("\n")
            f.write("\n};\n")
            f.write(f"const uint64_t {name}_len = {counter};\n\n")
            # Remove the temporary .spv file once its bytes have been embedded.
            os.remove(path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GGML Vulkan Shader Generator")
    parser.add_argument("--glslc", help="Path to glslc")
    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    if args.glslc:
        GLSLC = args.glslc

    asyncio.run(main())
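# Example invocation (assumes the .comp sources live in ./vulkan-shaders and glslc is installed):
#   python ggml_vk_generate_shaders.py --glslc /usr/bin/glslc --verbose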