fattn-vec.cuh

#include "common.cuh"
#include "fattn-common.cuh"

static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) {
    return 128;
    GGML_UNUSED(cc);
}

static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
    return 128;
}

// Currently LLVM with the amdgcn target does not support unrolling loops
// that contain a break that cannot be resolved at compile time.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__

template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)
static __global__ void flash_attn_ext_vec(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
        const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
        const int32_t nb11, const int32_t nb12, const int64_t nb13,
        const int32_t nb21, const int32_t nb22, const int64_t nb23,
        const int32_t ne31, const int32_t ne32, const int32_t ne33,
        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE

    // Skip unused kernel variants for faster compilation:
    if (use_logit_softcap && !(D == 128 || D == 256)) {
        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
            max_bias, m0, m1, n_head_log2, logit_softcap,
            ne00, ne01, ne02, ne03,
            nb01, nb02, nb03,
            ne10, ne11, ne12, ne13,
            nb11, nb12, nb13,
            nb21, nb22, nb23,
            ne31, ne32, ne33,
            nb31, nb32, nb33);
        NO_DEVICE_CODE;
        return;
    }

    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

#ifdef GGML_USE_HIP
#ifdef RDNA
    constexpr int nthreads_KQ_q = 2;
#else
    constexpr int nthreads_KQ_q = 4;
#endif // RDNA
    constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32);
#else
    constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32);
    constexpr int nthreads_V_q  = (D/4 < 32 ? D/4 : 32);
#endif // GGML_USE_HIP

    constexpr int nthreads    = ggml_cuda_fattn_vec_get_nthreads_device();
    constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
    constexpr int nthreads_V  = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;

    static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
    static_assert(WARP_SIZE % nthreads_V  == 0, "bad nthreads_V");

    constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
    constexpr int V_cols_per_iter   = WARP_SIZE / nthreads_V;
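
    // Work split within a warp: each group of nthreads_KQ lanes computes the KQ dot products for
    // nthreads_KQ consecutive K rows (one D/nthreads_KQ slice per lane, reduced with
    // warp_reduce_sum<nthreads_KQ>), and each group of nthreads_V lanes accumulates one V row per
    // iteration with V_rows_per_thread head elements per lane. For example, for D == 128 with f16
    // K/V, WARP_SIZE == 32 and a 16-byte maximum copy width (cpy_nb == 16, an assumption), this
    // gives nthreads_KQ == nthreads_V == 8, V_rows_per_thread == 8 and V_cols_per_iter == 4.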

    constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
#ifdef FAST_FP16_AVAILABLE
    constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half,  V_rows_per_thread>();
#else
    constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
#endif // FAST_FP16_AVAILABLE

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

    const int sequence  = blockIdx.z / ne02;
    const int head      = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    Q += nb03*sequence + nb02*head + nb01*ic0;
    K += nb13*sequence + nb12*(head / gqa_ratio);
    V += nb23*sequence + nb22*(head / gqa_ratio);

    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
    constexpr int nwarps = nthreads / WARP_SIZE;
    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
    __builtin_assume(tid < nthreads);

    constexpr int ne_KQ      = ncols*D;
    constexpr int ne_combine = nwarps*V_cols_per_iter*D;
#ifdef FAST_FP16_AVAILABLE
    half2  VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
    __shared__ half  KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#else
    float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
    __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#endif // FAST_FP16_AVAILABLE
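
    // The shared buffer KQ serves three purposes: scratch space while converting Q to q8_1, storage
    // for the per-thread softmax weights during the KV loop, and staging space for the per-warp VKQ
    // accumulators in the final combine; hence it is sized as max(ne_KQ, ne_combine).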

    float KQ_max[ncols];
    float KQ_sum[ncols];
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        KQ_max[j] = -FLT_MAX/2.0f;
        KQ_sum[j] = 0.0f;
    }
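
    // KQ_max/KQ_sum are the running maximum and denominator of the numerically stable (online)
    // softmax: when the maximum rises from m_old to m_new, previously accumulated sums and VKQ
    // values are rescaled by expf(m_old - m_new), so sum_i expf(s_i - m_new) stays correct without
    // revisiting earlier KV positions.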

    // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
#ifdef FAST_FP16_AVAILABLE
    half2  Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
#else
    float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
#endif // FAST_FP16_AVAILABLE
    int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
    float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
    if constexpr (Q_q8_1) {
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (j0 + nwarps > ncols && j >= ncols) {
                break;
            }

            // Reuse KQ as temporary storage for converting Q to q8_1:
            int    * tmp_q_i32 = (int    *) &KQ[j*D];
            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));

            // Set memory to zero if out of bounds:
            if (ncols > 1 && ic0 + j >= ne01) {
#pragma unroll
                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                    const int i = i0 + threadIdx.x;

                    if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
                        tmp_q_i32[i] = 0;
                    }
                }
                if (threadIdx.x < D/QK8_1) {
                    tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
                }
            } else {
                const float * Q_f = (const float *) (Q + j*nb01);
                constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE;
#pragma unroll
                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
                    quantize_q8_1_to_shared<float2, nthreads_quantize>
                        (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
                }
            }
        }

        __syncthreads();

#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            int    * tmp_q_i32 = (int    *) &KQ[j*D];
            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));

#pragma unroll
            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
                const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ);

                Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
                Q_ds[j][i0/nthreads_KQ]  = tmp_q_ds[i/QI8_1];
            }
        }

        __syncthreads();
    } else {
#ifdef FAST_FP16_AVAILABLE
        const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            const float2 * Q_j = (const float2 *) (Q + j*nb01);
#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
                const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;

                float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
                if (ncols == 1 || ic0 + j < ne01) {
                    ggml_cuda_memcpy_1<cpy_nb>(tmp,            &Q_j[i]);
                    ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
                }
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne; ++i1) {
                    Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y);
                }
            }
#pragma unroll
            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
                Q_reg[j][k] *= scale_h2;
            }
        }
#else
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            const float2 * Q_j = (const float2 *) (Q + j*nb01);
#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
                const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;

                if (ncols == 1 || ic0 + j < ne01) {
                    ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ],            &Q_j[i]);
                    ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
                }
            }
#pragma unroll
            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
                Q_reg[j][k].x *= scale;
                Q_reg[j][k].y *= scale;
            }
        }
#endif // FAST_FP16_AVAILABLE
    }

    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    K     += blockIdx.y*nthreads * nb11;
    V     += blockIdx.y*nthreads * nb21;
    maskh += blockIdx.y*nthreads;
    for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads,
             // Increment pointers after each loop:
             K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) {

        // Calculate KQ tile and keep track of new maximum KQ values:

        float KQ_reg[ncols]; // KQ in registers.

        float KQ_max_new[ncols];
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            KQ_max_new[j] = KQ_max[j];
        }

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
            const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0;

#pragma unroll
            for (int j = 0; j < ncols; ++j) {
                float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
                sum = warp_reduce_sum<nthreads_KQ>(sum);

                if (use_logit_softcap) {
                    sum = logit_softcap*tanhf(sum);
                }

                if (mask) {
                    sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
                }

                KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);

                if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
                    KQ_reg[j] = sum;
                }
            }
        }

#pragma unroll
        for (int j = 0; j < ncols; ++j) {
#pragma unroll
            for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
                KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
            }
            const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
            KQ_max[j] = KQ_max_new[j];

            KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
            KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
            KQ[j*nthreads + tid] = KQ_reg[j];

#ifdef FAST_FP16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
            }
#else
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
                VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
            }
#endif // FAST_FP16_AVAILABLE
        }
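
        // At this point KQ[j*nthreads + tid] holds this thread's softmax weight for its KV position
        // within the current tile and the VKQ accumulators are rescaled to the updated maximum; the
        // loop below dequantizes the corresponding V rows and accumulates weight*V into VKQ.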

#ifndef GGML_USE_HIP
        __syncwarp();
#endif // GGML_USE_HIP

#pragma unroll
        for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
            const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);

#ifdef FAST_FP16_AVAILABLE
            half2 KQ_k[ncols];
#pragma unroll
            for (int j = 0; j < ncols; ++j) {
                KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
            }
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
                half2 tmp[V_rows_per_thread/2];
                dequantize_V(V + k*nb21, tmp,
                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
                    for (int j = 0; j < ncols; ++j) {
                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
                    }
                }
            }
#else
            float KQ_k[ncols];
#pragma unroll
            for (int j = 0; j < ncols; ++j) {
                KQ_k[j] = KQ[j*nthreads + k];
            }
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
                float2 tmp[V_rows_per_thread/2];
                dequantize_V(V + k*nb21, tmp,
                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
                    for (int j = 0; j < ncols; ++j) {
                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
                    }
                }
            }
#endif // FAST_FP16_AVAILABLE
        }
    }
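
    // Attention sinks add one extra logit per head that contributes expf(sink - KQ_max) to the
    // softmax denominator but no value vector; only blocks with blockIdx.y == 0 fold it into the
    // running maximum and sum so that it is counted exactly once.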
    if (sinks && blockIdx.y == 0) {
        const float sink = ((const float *) sinks)[head];

#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (j0 + nwarps > ncols && j >= ncols) {
                break;
            }

            const float kqmax_new_j  = fmaxf(sink, KQ_max[j]);
            const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);
            KQ_max[j] = kqmax_new_j;

            KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);

#ifdef FAST_FP16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
            }
#else
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
                VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
            }
#endif // FAST_FP16_AVAILABLE
        }
    }
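
    // Combine the per-warp results within the block: the running maxima and denominators are
    // exchanged through shared memory and reduced across warps, then the VKQ accumulators are
    // rescaled, staged in shared memory and summed into the output below.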
    __shared__ float KQ_max_shared[ncols][WARP_SIZE];
    __shared__ float KQ_sum_shared[ncols][WARP_SIZE];

#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        if (threadIdx.y == 0) {
            KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
            KQ_sum_shared[j][threadIdx.x] = 0.0f;
        }
    }

    __syncthreads();

#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        if (threadIdx.x == 0) {
            KQ_max_shared[j][threadIdx.y] = KQ_max[j];
        }
    }

    __syncthreads();

#pragma unroll
    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
        if (ncols > 1 && ic0 + j_VKQ >= ne01) {
            break;
        }

        float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
        kqmax_new = warp_reduce_max(kqmax_new);

        const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
        KQ_max[j_VKQ] = kqmax_new;

#ifdef FAST_FP16_AVAILABLE
        half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
            + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
        const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale);
#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
            VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
        }

#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
            const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);

            ggml_cuda_memcpy_1<V_rows_per_thread*sizeof(half)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
        }
#else
        float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
            + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
            VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale;
            VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale;
        }

#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
            const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);

            ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
            ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
        }
#endif // FAST_FP16_AVAILABLE

        KQ_sum[j_VKQ] *= kqmax_scale;
        KQ_sum[j_VKQ]  = warp_reduce_sum(KQ_sum[j_VKQ]);
        if (threadIdx.x == 0) {
            KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ];
        }

        __syncthreads();

        if (nthreads <= D || tid < D) {
            KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x];
            KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);

#pragma unroll
            for (int i0 = 0; i0 < D; i0 += nthreads) {
                float dst_val = 0;
#pragma unroll
                for (int w = 0; w < nwarps; ++w) {
#pragma unroll
                    for (int v = 0; v < V_cols_per_iter; ++v) {
                        dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
                    }
                }
                if (gridDim.y == 1) {
                    dst_val /= KQ_sum[j_VKQ];
                }
                dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
            }
        }

        if (j_VKQ < ncols-1) {
            __syncthreads();
        }
    }
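
    // If the KV dimension is split over multiple blocks (gridDim.y > 1), each block has written an
    // unnormalized partial result above and stores its (max, sum) pair here; the partials are
    // expected to be merged by a separate combine pass issued from launch_fattn.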
    if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) {
        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
    }
#else
    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
        max_bias, m0, m1, n_head_log2, logit_softcap,
        ne00, ne01, ne02, ne03,
        nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
        nb11, nb12, nb13,
        nb21, nb22, nb23,
        ne31, ne32, ne33,
        nb31, nb32, nb33);
    NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__

template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

    const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
    const int nwarps   = nthreads / WARP_SIZE;
    fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
    constexpr bool   need_f16_K    = false;
    constexpr bool   need_f16_V    = false;
    constexpr size_t nbytes_shared = 0;
    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
}

template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];
    const ggml_tensor * K   = dst->src[1];
    const ggml_tensor * V   = dst->src[2];

    GGML_ASSERT(K->type == type_K);
    GGML_ASSERT(V->type == type_V);

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

    if (Q->ne[1] == 1) {
        constexpr int cols_per_block = 1;

        if (logit_softcap == 0.0f) {
            constexpr bool use_logit_softcap = false;
            ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
        } else {
            constexpr bool use_logit_softcap = true;
            ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
        }
        return;
    }

    constexpr int cols_per_block = 2;

    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
    } else {
        constexpr bool use_logit_softcap = true;
        ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
    }
}
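
// Dispatch note: a single Q column (Q->ne[1] == 1, as in token-by-token decoding) uses
// cols_per_block == 1, everything else uses cols_per_block == 2; each variant is compiled with and
// without logit softcap.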

#define DECL_FATTN_VEC_CASE(D, type_K, type_V)                              \
    template void ggml_cuda_flash_attn_ext_vec_case                         \
    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \

#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K)             \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16);  \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \

EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)

EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)

EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
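
// A minimal sketch of how one of the declared cases might be instantiated in a separate translation
// unit; the file name and the exact set of instances are assumptions, not part of this header:
//
//     // fattn-vec-instance-f16-f16.cu (hypothetical)
//     #include "fattn-vec.cuh"
//     DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);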