// mul_mm.comp — tiled matrix-multiplication compute shader (Vulkan GLSL),
// with optional cooperative-matrix and mixture-of-experts (MUL_MAT_ID) paths.
  1. #version 450
  2. #extension GL_EXT_control_flow_attributes : enable
  3. #extension GL_EXT_shader_16bit_storage : require
  4. #ifdef FLOAT16
  5. #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
  6. #endif
  7. #if defined(DATA_A_IQ1_M)
  8. #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
  9. #endif
  10. #if defined(DATA_A_BF16) && defined(COOPMAT)
  11. #extension GL_EXT_bfloat16 : enable
  12. #endif
  13. #ifdef COOPMAT
  14. #extension GL_KHR_cooperative_matrix : enable
  15. #extension GL_KHR_memory_scope_semantics : enable
  16. #endif
  17. #if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
  18. #extension GL_KHR_shader_subgroup_basic : enable
  19. #extension GL_KHR_shader_subgroup_ballot : enable
  20. #endif
  21. #ifdef MUL_MAT_ID
  22. #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
  23. #endif
  24. #include "types.comp"
  25. #ifndef LOAD_VEC_A
  26. #define LOAD_VEC_A 1
  27. #endif
  28. #ifndef LOAD_VEC_B
  29. #define LOAD_VEC_B 1
  30. #endif
  31. #if !defined(TO_FLOAT_TYPE)
  32. #define TO_FLOAT_TYPE FLOAT_TYPE
  33. #endif
  34. layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
  35. layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
  36. #if defined(A_TYPE_PACKED16)
  37. layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
  38. #endif
  39. #if defined(A_TYPE_PACKED32)
  40. layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
  41. #endif
  42. layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
  43. layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
  44. #ifdef MUL_MAT_ID
  45. layout (binding = 3) readonly buffer IDS {int data_ids[];};
  46. #endif
  47. layout (push_constant) uniform parameter
  48. {
  49. uint M;
  50. uint N;
  51. uint K;
  52. uint stride_a;
  53. uint stride_b;
  54. uint stride_d;
  55. uint batch_stride_a;
  56. uint batch_stride_b;
  57. uint batch_stride_d;
  58. #ifdef MUL_MAT_ID
  59. uint nei0;
  60. uint nei1;
  61. uint nbi1;
  62. uint ne11;
  63. #else
  64. uint k_split;
  65. uint ne02;
  66. uint ne12;
  67. uint broadcast2;
  68. uint broadcast3;
  69. #endif
  70. } p;
  71. layout (constant_id = 0) const uint BLOCK_SIZE = 64;
  72. layout (constant_id = 1) const uint BM = 64;
  73. layout (constant_id = 2) const uint BN = 64;
  74. layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
  75. layout (constant_id = 4) const uint WM = 32;
  76. layout (constant_id = 5) const uint WN = 32;
  77. layout (constant_id = 6) const uint WMITER = 2;
  78. layout (constant_id = 7) const uint TM = 4;
  79. layout (constant_id = 8) const uint TN = 2;
  80. layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
  81. layout (constant_id = 10) const uint WARP = 32;
  82. #ifdef COOPMAT
  83. #define SHMEM_STRIDE (BK + 8)
  84. #else
  85. #define SHMEM_STRIDE (BK + 1)
  86. #endif
  87. shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
  88. shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
  89. #define NUM_WARPS (BLOCK_SIZE / WARP)
  90. #ifdef MUL_MAT_ID
  91. shared u16vec2 row_ids[BN];
  92. uint _ne1;
  93. #ifdef MUL_MAT_ID_USE_SUBGROUPS
  94. shared uvec4 ballots_sh[NUM_WARPS];
  95. void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
  96. _ne1 = 0;
  97. uint num_elements = p.nei1 * p.nei0;
  98. uint nei0shift = findLSB(p.nei0);
  99. uint ids[16];
  100. uint iter = 0;
  101. for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
  102. // prefetch up to 16 elements
  103. if (iter == 0) {
  104. [[unroll]] for (uint k = 0; k < 16; ++k) {
  105. uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
  106. bool in_range = i < num_elements;
  107. uint ii1;
  108. if (nei0_is_pow2) {
  109. ii1 = i >> nei0shift;
  110. } else {
  111. ii1 = i / p.nei0;
  112. }
  113. uint ii0 = i - ii1 * p.nei0;
  114. ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
  115. }
  116. }
  117. uint i = j + gl_LocalInvocationIndex;
  118. bool in_range = i < num_elements;
  119. uint ii1;
  120. if (nei0_is_pow2) {
  121. ii1 = i >> nei0shift;
  122. } else {
  123. ii1 = i / p.nei0;
  124. }
  125. uint ii0 = i - ii1 * p.nei0;
  126. uint id = ids[iter++];
  127. uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
  128. ballots_sh[gl_SubgroupID] = ballot;
  129. barrier();
  130. uint subgroup_base = 0;
  131. uint total = 0;
  132. for (uint k = 0; k < gl_NumSubgroups; ++k) {
  133. if (k == gl_SubgroupID) {
  134. subgroup_base = total;
  135. }
  136. total += subgroupBallotBitCount(ballots_sh[k]);
  137. }
  138. barrier();
  139. uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
  140. if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
  141. row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
  142. }
  143. _ne1 += total;
  144. iter &= 15;
  145. if (_ne1 >= (ic + 1) * BN) {
  146. break;
  147. }
  148. }
  149. barrier();
  150. }
  151. #endif // MUL_MAT_ID_USE_SUBGROUPS
  152. #endif // MUL_MAT_ID
  153. #ifdef COOPMAT
  154. shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
  155. #endif
  156. #include "mul_mm_funcs.comp"
  157. void main() {
  158. #ifdef NEEDS_INIT_IQ_SHMEM
  159. init_iq_shmem(gl_WorkGroupSize);
  160. #endif
  161. #ifdef MUL_MAT_ID
  162. const uint expert_idx = gl_GlobalInvocationID.z;
  163. #else
  164. const uint batch_idx = gl_GlobalInvocationID.z;
  165. const uint i13 = batch_idx / p.ne12;
  166. const uint i12 = batch_idx % p.ne12;
  167. const uint i03 = i13 / p.broadcast3;
  168. const uint i02 = i12 / p.broadcast2;
  169. const uint batch_idx_a = i03 * p.ne02 + i02;
  170. #endif
  171. const uint blocks_m = (p.M + BM - 1) / BM;
  172. const uint ir = gl_WorkGroupID.x % blocks_m;
  173. const uint ik = gl_WorkGroupID.x / blocks_m;
  174. const uint ic = gl_WorkGroupID.y;
  175. const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
  176. const uint WSUBM = WM / WMITER;
  177. const uint WSUBN = WN / WNITER;
  178. #ifdef COOPMAT
  179. const uint warp_i = gl_SubgroupID;
  180. const uint tiw = gl_SubgroupInvocationID;
  181. const uint cms_per_row = WM / TM;
  182. const uint cms_per_col = WN / TN;
  183. const uint storestride = WARP / TM;
  184. const uint store_r = tiw % TM;
  185. const uint store_c = tiw / TM;
  186. #else
  187. const uint warp_i = gl_LocalInvocationID.x / WARP;
  188. const uint tiw = gl_LocalInvocationID.x % WARP;
  189. const uint tiwr = tiw % (WSUBM / TM);
  190. const uint tiwc = tiw / (WSUBM / TM);
  191. #endif
  192. const uint warp_r = warp_i % (BM / WM);
  193. const uint warp_c = warp_i / (BM / WM);
  194. const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
  195. const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
  196. const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
  197. const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
  198. const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
  199. const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
  200. #ifdef MUL_MAT_ID
  201. #ifdef MUL_MAT_ID_USE_SUBGROUPS
  202. if (bitCount(p.nei0) == 1) {
  203. load_row_ids(expert_idx, true, ic);
  204. } else {
  205. load_row_ids(expert_idx, false, ic);
  206. }
  207. #else
  208. _ne1 = 0;
  209. for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
  210. for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
  211. if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
  212. if (_ne1 >= ic * BN) {
  213. row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
  214. }
  215. _ne1++;
  216. }
  217. }
  218. }
  219. barrier();
  220. #endif
  221. // Workgroup has no work
  222. if (ic * BN >= _ne1) return;
  223. #endif
  224. #ifdef MUL_MAT_ID
  225. const uint start_k = 0;
  226. const uint end_k = p.K;
  227. #else
  228. const uint start_k = ik * p.k_split;
  229. const uint end_k = min(p.K, (ik + 1) * p.k_split);
  230. #endif
  231. uint pos_a = (
  232. #ifdef MUL_MAT_ID
  233. expert_idx * p.batch_stride_a +
  234. #else
  235. batch_idx_a * p.batch_stride_a +
  236. #endif
  237. ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
  238. #ifdef MUL_MAT_ID
  239. uint pos_b = 0;
  240. #else
  241. uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
  242. #endif
  243. #ifdef COOPMAT
  244. coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
  245. coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
  246. coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
  247. [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
  248. sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
  249. }
  250. #else
  251. ACC_TYPE sums[WMITER * TM * WNITER * TN];
  252. FLOAT_TYPE cache_a[WMITER * TM];
  253. FLOAT_TYPE cache_b[TN];
  254. [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
  255. sums[i] = ACC_TYPE(0.0f);
  256. }
  257. #endif
  258. for (uint block = start_k; block < end_k; block += BK) {
  259. [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
  260. load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block + loadr_a, end_k);
  261. }
  262. [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
  263. #if !defined(MUL_MAT_ID)
  264. load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block + loadr_b, end_k);
  265. #else
  266. load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block + loadr_b, end_k);
  267. #endif
  268. }
  269. barrier();
  270. pos_a += BK / LOAD_VEC_A;
  271. pos_b += BK / LOAD_VEC_B;
  272. #ifdef COOPMAT
  273. [[unroll]] for (uint i = 0; i < BK; i += TK) {
  274. [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
  275. // Load from shared into cache
  276. coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
  277. [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
  278. coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
  279. sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
  280. }
  281. }
  282. }
  283. #else
  284. [[unroll]] for (uint i = 0; i < BK; i++) {
  285. // Load from shared into cache
  286. [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
  287. [[unroll]] for (uint j = 0; j < TM; j++) {
  288. cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
  289. }
  290. }
  291. [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
  292. [[unroll]] for (uint j = 0; j < TN; j++) {
  293. cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
  294. }
  295. [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
  296. [[unroll]] for (uint cc = 0; cc < TN; cc++) {
  297. [[unroll]] for (uint cr = 0; cr < TM; cr++) {
  298. const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
  299. sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
  300. }
  301. }
  302. }
  303. }
  304. }
  305. #endif
  306. barrier();
  307. }
  308. #if defined(ACC_TYPE_MAX)
  309. #ifdef COOPMAT
  310. [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
  311. [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
  312. sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
  313. }
  314. }
  315. #else
  316. [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
  317. sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
  318. }
  319. #endif
  320. #endif
  321. const uint dr = ir * BM + warp_r * WM;
  322. const uint dc = ic * BN + warp_c * WN;
  323. #ifndef MUL_MAT_ID
  324. const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
  325. #endif
  326. #ifdef COOPMAT
  327. #ifdef MUL_MAT_ID
  328. [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
  329. [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
  330. coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
  331. [[unroll]] for (uint col = 0; col < TN; col += storestride) {
  332. const uint row_i = dc + cm_col * TN + col + store_c;
  333. if (row_i >= _ne1) break;
  334. const u16vec2 row_idx = row_ids[row_i - ic * BN];
  335. if (dr + cm_row * TM + store_r < p.M) {
  336. data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
  337. }
  338. }
  339. }
  340. }
  341. #else
  342. const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
  343. [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
  344. [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
  345. const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
  346. if (is_aligned && is_in_bounds) {
  347. // Full coopMat is within bounds and stride_d is aligned with 16B
  348. coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
  349. coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
  350. } else if (is_in_bounds) {
  351. // Full coopMat is within bounds, but stride_d is not aligned
  352. coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
  353. [[unroll]] for (uint col = 0; col < TN; col += storestride) {
  354. data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
  355. }
  356. } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
  357. // Partial coopMat is within bounds
  358. coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
  359. [[unroll]] for (uint col = 0; col < TN; col += storestride) {
  360. if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
  361. data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
  362. }
  363. }
  364. }
  365. }
  366. }
  367. #endif // MUL_MAT_ID
  368. #else
  369. [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
  370. [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
  371. const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
  372. const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
  373. [[unroll]] for (uint cc = 0; cc < TN; cc++) {
  374. #ifdef MUL_MAT_ID
  375. const uint row_i = dc_warp + cc;
  376. if (row_i >= _ne1) break;
  377. const u16vec2 row_idx = row_ids[row_i - ic * BN];
  378. #endif // MUL_MAT_ID
  379. [[unroll]] for (uint cr = 0; cr < TM; cr++) {
  380. #ifdef MUL_MAT_ID
  381. if (dr_warp + cr < p.M) {
  382. data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
  383. }
  384. #else
  385. if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
  386. data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
  387. }
  388. #endif // MUL_MAT_ID
  389. }
  390. }
  391. }
  392. }
  393. #endif // COOPMAT
  394. }