flash_attn_cm2.comp

#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_EXT_null_initializer : enable

#include "types.comp"
#include "dequant_funcs_cm2.comp"
#include "flash_attn_base.comp"
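
// Q, K, V, and the optional mask are bound as raw byte buffers; typed views
// are created on load (K/V may additionally be dequantized, see DECODEFUNC below).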
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
layout (binding = 3) readonly buffer M {uint8_t data_m[];};

ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    return max(x, y);
}
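
// Reduction that just keeps the first operand: combined with a row reduce it
// replicates a row's value across the output row, which is used below to
// resize matrices whose rows are already constant.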
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    return x;
}

// Replace matrix elements with row >= numRows or col >= numCols by 'replace'.
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
    if (row >= numRows || col >= numCols) {
        return replace;
    }
    return elem;
}

ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
{
    return exp(elem);
}

ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
{
    return max(elem0, elem1);
}
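
// When K/V are stored quantized (BLOCK_SIZE is defined), pass a
// dequantization callback to the tensor loads; otherwise load directly.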
#if defined(BLOCK_SIZE)
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif

// Store the output when doing grouped query attention.
// Rows are indexed by Q's dimension 2; only the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    if (r < N && c < HSV) {
        uint32_t offset = (iq2 + r) * HSV + c;
        data_o[o_offset + offset] = D_TYPE(elem);
    }
    return elem;
}

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    init_indices();
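
    // 2D tensor layouts describing Q, K, and V in memory. Q always uses
    // constant (zero) clamping; K and V use the shader's compile-time Clamp
    // mode. tensorViewTranspose swaps the two dimensions so K can be loaded
    // transposed for the Q*K^T matmul.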
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
    tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);

    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if defined(BLOCK_SIZE)
    tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif

    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);

    // Hint to the compiler that strides are aligned for the aligned variant of the shader
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
    {
        q_stride &= ~7;
#if !defined(BLOCK_SIZE)
        k_stride &= ~7;
        v_stride &= ~7;
#endif
        m_stride &= ~7;
    }
    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
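
    // Load this workgroup's Br x HSK block of Q and pre-scale it by the
    // attention scale, so S = Q*K^T comes out of the matmul already scaled.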
    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;

    uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));

    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);

    // ALiBi
    if (p.max_bias > 0.0f) {
        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
    }
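
    // Byte offset of this workgroup's mask slice; the modulo arithmetic lets
    // a smaller mask be broadcast over the outer (nem2/nem3) dimensions.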
    uint32_t m_offset = 0;
    if (p.nem2 != 1 || p.nem3 != 1) {
        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
    }
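
    // Main loop over blocks of Bc KV columns: standard flash-attention online
    // softmax, maintaining the running row max M, running row sum L, and the
    // unnormalized output accumulator O.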
    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;

        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);
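
        // Optional logit softcap: S = softcap * tanh(S). The matching divide
        // by softcap is expected to have been folded into p.scale by the host.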
        if (p.logit_softcap != 0.0f) {
            [[unroll]]
            for (int k = 0; k < S.length(); ++k) {
                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
            }
        }

        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
            tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
            tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;

            coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
        }

        // Clear padding elements to -inf, so they don't contribute to rowmax
        if (Clamp != 0 &&
            ((j + 1) * Bc > KV ||
             (i + 1) * Br > N)) {

            uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
            uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
        }

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;

        coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;

        // M = max(rowmax, Mold)
        // P = e^(S - M)
        // eM = e^(Mold - M)
        coopMatPerElementNV(M, rowmax, Max, Mold);
        coopMatPerElementNV(P, S - M, Exp);
        coopMatPerElementNV(eM, Mold - M, Exp);

        // Clear padding elements to 0, so they don't contribute to rowsum
        if (Clamp != 0 &&
            ((j + 1) * Bc > KV ||
             (i + 1) * Br > N)) {

            uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
            uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
        }

        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);

        // Compute rowsum by multiplying by a matrix of all ones.
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);

        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
        rowsum = coopMatMulAdd(P_A, One, rowsum);
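
        // Load the Bc x HSV block of V (dequantized on load when BLOCK_SIZE is defined).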
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
        uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
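
        // Rescale the running sum: L = e^(Mold - M) * L + rowsum(P).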
        L = eM*L + rowsum;

        // This is the "diagonal" matrix in the paper, but since we do a
        // componentwise multiply rather than a matrix multiply it has the
        // diagonal element smeared across the row.
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;

        // resize eM by using smear/reduce
        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // multiply with fp16 accumulation, then add to O.
        coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
        PV = coopMatMulAdd(P_A, V, PV);

        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
    }

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
        coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
        coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
        return;
    }

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;

    // resize L by using smear/reduce
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
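
    // Attention sinks: fold a per-head sink logit into the softmax
    // normalization, rescaling O when the sink exceeds the running max.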
    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
        coopMatPerElementNV(S, S, perElemOpGetSink, iq2);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;

        // resize M by using smear/reduce
        coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // O, Ldiag, Mr all have the same type so all element locations match
        [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
            ACC_TYPE sink = S[i];

            ACC_TYPE ms = ACC_TYPE(1.0f);
            ACC_TYPE vs = ACC_TYPE(1.0f);

            if (sink > Mr[i]) {
                ms = exp(Mr[i] - sink);
                O[i] *= ms;
            } else {
                vs = exp(sink - Mr[i]);
            }

            Ldiag[i] = Ldiag[i]*ms + vs;
        }
    }
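
    // Final softmax normalization: O = O / L, done as a multiply by 1/L.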
    [[unroll]]
    for (int k = 0; k < Ldiag.length(); ++k) {
        Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
    }

    O = Ldiag*O;

#if defined(ACC_TYPE_MAX)
    [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
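
    // Store the result: scatter rows via perElemOpGqaStore for grouped-query
    // attention, otherwise store the tile through a permuted (ne2, ne1) view.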
    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;

    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
    if (p.gqa_ratio > 1) {
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
    } else {
        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);

        // permute dimensions
        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
    }
}