// fattn.cu

#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f16.cuh"
#include "fattn-tile-f32.cuh"
#include "fattn-vec-f16.cuh"
#include "fattn-vec-f32.cuh"
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"

#include <cstdint>
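
// Dispatcher for the WMMA (tensor core) F16 FlashAttention kernels. The head size Q->ne[0]
// selects the template instantiation; the batch size Q->ne[1] and the requested precision
// (KQV->op_params[2]) select the number of columns per block and whether the KQ accumulator
// is float or half.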
static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];

    const int32_t precision = KQV->op_params[2];

    if (precision != GGML_PREC_DEFAULT) {
        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
            constexpr int cols_per_block = 16;
            switch (Q->ne[0]) {
                case 64:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
                    break;
                case 80:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
                    break;
                case 96:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
                    break;
                case 112:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
                    break;
                case 128:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                    break;
                case 256:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
                    break;
                default:
                    GGML_ASSERT(false);
                    break;
            }
        } else {
            constexpr int cols_per_block = 32;
            switch (Q->ne[0]) {
                case 64:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
                    break;
                case 80:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
                    break;
                case 96:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
                    break;
                case 112:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
                    break;
                case 128:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                    break;
                // case 256:
                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                //     break;
                default:
                    GGML_ASSERT(false);
                    break;
            }
        }
        return;
    }

    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
        constexpr int cols_per_block = 8;
        switch (Q->ne[0]) {
            case 64:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                break;
            case 96:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
                break;
            case 128:
                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
                break;
            case 256:
                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                break;
            default:
                GGML_ASSERT(false);
                break;
        }
        return;
    }

    if (Q->ne[1] <= 32) {
        constexpr int cols_per_block = 16;
        switch (Q->ne[0]) {
            case 64:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                break;
            case 80:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
                break;
            case 96:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
                break;
            case 112:
                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
                break;
            case 128:
                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
                break;
            case 256:
                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                break;
            default:
                GGML_ASSERT(false);
                break;
        }
        return;
    }

    constexpr int cols_per_block = 32;
    switch (Q->ne[0]) {
        case 64:
            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
            break;
        case 80:
            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
            break;
        case 96:
            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
            break;
        case 112:
            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
            break;
        case 128:
            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
            break;
        case 256:
            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
}
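
// Dispatcher for the F16 vector kernels. FATTN_VEC_F16_CASE(D, type_K, type_V) expands to a
// guarded call that returns as soon as the (head size, K type, V type) combination matches:
//
//   if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {
//       ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst);
//       return;
//   }
//
// With GGML_CUDA_FA_ALL_QUANTS defined, kernels for all supported K/V quantization type
// combinations are compiled; otherwise only a reduced set is available.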
#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }                                                                       \

static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];
#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#else
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}
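
// Same dispatch pattern as above, but selecting the F32 vector kernels (used when fast FP16
// arithmetic is unavailable or higher precision is requested).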
#define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }                                                                       \

static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];
#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#else
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}
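
// Top-level FlashAttention dispatcher: selects between the vec, tile and WMMA kernels based on
// the device capabilities, the batch size Q->ne[1] and the requested precision.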
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];

    ggml_cuda_set_device(ctx.device);

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int32_t precision = KQV->op_params[2];

    // On AMD the tile kernels perform poorly, use the vec kernel instead:
    if (cc >= CC_OFFSET_AMD) {
        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }
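
    // Without fast FP16 arithmetic, fall back to the F32 kernels: vec for small batches, tile otherwise.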
    if (!fast_fp16_available(cc)) {
        if (Q->ne[1] <= 8) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
        }
        return;
    }
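
    // Without FP16 tensor cores (MMA), use the non-WMMA F16 kernels.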
    if (!fp16_mma_available(cc)) {
        if (Q->ne[1] <= 8) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
        }
        return;
    }
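
    // For batch size 1 prefer the vector kernels: F16 by default, F32 when higher precision is
    // requested and the head size allows it.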
    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
        if (precision == GGML_PREC_DEFAULT) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
            return;
        } else if (Q->ne[0] <= 128) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
            return;
        }
    }

    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
}