// multi_add.comp
  1. #version 450
  2. #extension GL_EXT_shader_16bit_storage : require
  3. #extension GL_EXT_nonuniform_qualifier : enable
  4. #extension GL_EXT_control_flow_attributes : require
  5. #if ADD_RMS
  6. #extension GL_KHR_shader_subgroup_arithmetic : enable
  7. #extension GL_KHR_shader_subgroup_basic : enable
  8. #endif
  9. #include "rte.glsl"
  10. #include "types.glsl"
  11. #include "utils.glsl"
  12. layout (push_constant) uniform parameter2
  13. {
  14. // shape for dst
  15. uint ne20; uint ne21; uint ne22; uint ne23;
  16. // strides for srcs+dst
  17. uint nb[12][4];
  18. uint rms_partials;
  19. } p;
  20. // Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
  21. // layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
  22. // layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
  23. layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
  24. layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
  25. layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
  26. layout(constant_id = 0) const uint num_srcs = 2;
  27. uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
  28. return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
  29. }
  30. uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
  31. uint nb20 = p.nb[num_srcs][0];
  32. uint nb21 = p.nb[num_srcs][1];
  33. uint nb22 = p.nb[num_srcs][2];
  34. uint nb23 = p.nb[num_srcs][3];
  35. return i03*nb23 + i02*nb22 + i01*nb21 + i00*nb20;
  36. }
  37. uint get_idx() {
  38. return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
  39. }
  40. const uint num_threads = 256;
  41. layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
  42. #if ADD_RMS
  43. // XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
  44. shared FLOAT_TYPE sumsh[num_threads];
  45. #endif
  46. void main() {
  47. uint idx = get_idx();
  48. uint orig_idx = idx;
  49. uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
  50. // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
  51. const uint num_iter = 2;
  52. FLOAT_TYPE sum_sq = 0;
  53. [[unroll]] for (uint i = 0; i < num_iter; ++i) {
  54. if (idx >= ne) {
  55. continue;
  56. }
  57. uint i00, i01, i02, i03;
  58. get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23);
  59. FLOAT_TYPE sum = FLOAT_TYPE(0);
  60. [[unroll]] for (uint s = 0; s < num_srcs; ++s) {
  61. sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
  62. }
  63. sum_sq += sum*sum;
  64. d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
  65. idx += num_threads;
  66. }
  67. #if ADD_RMS
  68. if (p.rms_partials != 0) {
  69. // reduce the sum within each subgroup, then across subgroups
  70. const uint NumSubgroups = num_threads / gl_SubgroupSize;
  71. sum_sq = subgroupAdd(sum_sq);
  72. if (gl_SubgroupInvocationID == 0) {
  73. sumsh[gl_SubgroupID] = sum_sq;
  74. }
  75. barrier();
  76. [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
  77. if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
  78. sum_sq += sumsh[gl_SubgroupID + s];
  79. sumsh[gl_SubgroupID] = sum_sq;
  80. }
  81. barrier();
  82. }
  83. if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
  84. partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
  85. }
  86. }
  87. #endif
  88. }